crfm_helm-0.5.7-py3-none-any.whl → crfm_helm-0.5.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of crfm-helm has been flagged as possibly problematic.
Files changed (333)
  1. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/METADATA +7 -77
  2. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/RECORD +315 -282
  3. helm/benchmark/adaptation/adapter_spec.py +10 -0
  4. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +11 -3
  5. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +11 -8
  6. helm/benchmark/annotation/aci_bench_annotator.py +11 -22
  7. helm/benchmark/annotation/alrage_annotator.py +90 -0
  8. helm/benchmark/annotation/chw_care_plan_annotator.py +10 -21
  9. helm/benchmark/annotation/dischargeme_annotator.py +11 -22
  10. helm/benchmark/annotation/med_dialog_annotator.py +11 -22
  11. helm/benchmark/annotation/medalign_annotator.py +11 -22
  12. helm/benchmark/annotation/medi_qa_annotator.py +11 -22
  13. helm/benchmark/annotation/medication_qa_annotator.py +11 -22
  14. helm/benchmark/annotation/mental_health_annotator.py +11 -22
  15. helm/benchmark/annotation/mimic_bhc_annotator.py +11 -22
  16. helm/benchmark/annotation/mimic_rrs_annotator.py +11 -22
  17. helm/benchmark/annotation/model_as_judge.py +23 -18
  18. helm/benchmark/annotation/mtsamples_procedures_annotator.py +11 -22
  19. helm/benchmark/annotation/mtsamples_replicate_annotator.py +11 -22
  20. helm/benchmark/annotation/starr_patient_instructions_annotator.py +11 -22
  21. helm/benchmark/metrics/air_bench_metrics.py +3157 -1
  22. helm/benchmark/metrics/alrage_metric.py +35 -0
  23. helm/benchmark/metrics/basic_metrics.py +267 -2
  24. helm/benchmark/metrics/bbq_metrics.py +12 -0
  25. helm/benchmark/metrics/classification_metrics.py +19 -1
  26. helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +12 -1
  27. helm/benchmark/metrics/dry_run_metrics.py +30 -1
  28. helm/benchmark/metrics/efficiency_metrics.py +74 -0
  29. helm/benchmark/metrics/ehr_sql_metrics.py +57 -1
  30. helm/benchmark/metrics/evaluate_reference_metrics.py +311 -0
  31. helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +13 -1
  32. helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +13 -1
  33. helm/benchmark/metrics/ifeval_metrics.py +13 -1
  34. helm/benchmark/metrics/instruction_following_critique_metrics.py +41 -1
  35. helm/benchmark/metrics/kpi_edgar_metrics.py +21 -0
  36. helm/benchmark/metrics/language_modeling_metrics.py +13 -1
  37. helm/benchmark/metrics/live_qa_metrics.py +13 -1
  38. helm/benchmark/metrics/llm_jury_metrics.py +13 -1
  39. helm/benchmark/metrics/medcalc_bench_metrics.py +14 -1
  40. helm/benchmark/metrics/medec_metrics.py +25 -2
  41. helm/benchmark/metrics/metric.py +25 -0
  42. helm/benchmark/metrics/mimiciv_billing_code_metrics.py +32 -1
  43. helm/benchmark/metrics/omni_math_metrics.py +13 -1
  44. helm/benchmark/metrics/safety_metrics.py +13 -1
  45. helm/benchmark/metrics/seahelm_metrics.py +14 -1
  46. helm/benchmark/metrics/summac/model_summac.py +2 -2
  47. helm/benchmark/metrics/summarization_metrics.py +129 -1
  48. helm/benchmark/metrics/toxicity_metrics.py +31 -1
  49. helm/benchmark/metrics/ultra_suite_asr_classification_metrics.py +52 -0
  50. helm/benchmark/metrics/wildbench_metrics.py +21 -1
  51. helm/benchmark/presentation/run_display.py +13 -3
  52. helm/benchmark/presentation/run_entry.py +2 -2
  53. helm/benchmark/presentation/schema.py +5 -22
  54. helm/benchmark/presentation/summarize.py +180 -11
  55. helm/benchmark/presentation/taxonomy_info.py +20 -0
  56. helm/benchmark/run.py +1 -1
  57. helm/benchmark/run_expander.py +4 -0
  58. helm/benchmark/run_specs/arabic_run_specs.py +140 -16
  59. helm/benchmark/run_specs/bluex_run_specs.py +1 -1
  60. helm/benchmark/run_specs/classic_run_specs.py +2 -2
  61. helm/benchmark/run_specs/long_context_run_specs.py +2 -2
  62. helm/benchmark/run_specs/medhelm/__init__.py +0 -0
  63. helm/benchmark/run_specs/medhelm/benchmark_config.py +219 -0
  64. helm/benchmark/run_specs/medhelm_run_specs.py +362 -52
  65. helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +6 -2
  66. helm/benchmark/scenarios/aci_bench_scenario.py +23 -0
  67. helm/benchmark/scenarios/air_bench_scenario.py +21 -0
  68. helm/benchmark/scenarios/alrage_scenario.py +54 -0
  69. helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +23 -1
  70. helm/benchmark/scenarios/anthropic_red_team_scenario.py +12 -1
  71. helm/benchmark/scenarios/arabic_exams_scenario.py +114 -0
  72. helm/benchmark/scenarios/arabic_mmlu_scenario.py +8 -4
  73. helm/benchmark/scenarios/aratrust_scenario.py +19 -0
  74. helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification_scenario.py +24 -54
  75. helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +19 -48
  76. helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +22 -61
  77. helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +21 -29
  78. helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +21 -60
  79. helm/benchmark/scenarios/babi_qa_scenario.py +15 -0
  80. helm/benchmark/scenarios/banking77_scenario.py +21 -0
  81. helm/benchmark/scenarios/bbq_scenario.py +15 -0
  82. helm/benchmark/scenarios/best_chatgpt_prompts.yaml +473 -0
  83. helm/benchmark/scenarios/bird_sql_scenario.py +18 -0
  84. helm/benchmark/scenarios/bluex_scenario.py +6 -2
  85. helm/benchmark/scenarios/bold_scenario.py +15 -0
  86. helm/benchmark/scenarios/boolq_scenario.py +20 -0
  87. helm/benchmark/scenarios/chw_care_plan_scenario.py +23 -0
  88. helm/benchmark/scenarios/civil_comments_scenario.py +13 -0
  89. helm/benchmark/scenarios/clear_scenario.py +23 -0
  90. helm/benchmark/scenarios/cleva_scenario.py +479 -0
  91. helm/benchmark/scenarios/code_scenario.py +28 -0
  92. helm/benchmark/scenarios/commonsense_scenario.py +32 -0
  93. helm/benchmark/scenarios/compositional_instructions.yaml +70 -0
  94. helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +21 -0
  95. helm/benchmark/scenarios/copyright_scenario.py +35 -1
  96. helm/benchmark/scenarios/cti_to_mitre_scenario.py +21 -0
  97. helm/benchmark/scenarios/czech_bank_qa_scenario.py +18 -0
  98. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +22 -1
  99. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +23 -1
  100. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +22 -1
  101. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +21 -1
  102. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +13 -0
  103. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +13 -1
  104. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +13 -1
  105. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +13 -1
  106. helm/benchmark/scenarios/dischargeme_scenario.py +24 -0
  107. helm/benchmark/scenarios/disinformation_scenario.py +22 -0
  108. helm/benchmark/scenarios/dyck_language_scenario.py +15 -0
  109. helm/benchmark/scenarios/ehrshot_scenario.py +22 -0
  110. helm/benchmark/scenarios/enem_challenge_scenario.py +19 -0
  111. helm/benchmark/scenarios/entity_data_imputation_scenario.py +14 -0
  112. helm/benchmark/scenarios/entity_matching_scenario.py +14 -0
  113. helm/benchmark/scenarios/fin_qa_scenario.py +20 -0
  114. helm/benchmark/scenarios/financebench_scenario.py +21 -0
  115. helm/benchmark/scenarios/financial_phrasebank_scenario.py +21 -0
  116. helm/benchmark/scenarios/gold_commodity_news_scenario.py +21 -0
  117. helm/benchmark/scenarios/gpqa_scenario.py +18 -0
  118. helm/benchmark/scenarios/grammar_scenario.py +20 -1
  119. helm/benchmark/scenarios/gsm_scenario.py +21 -0
  120. helm/benchmark/scenarios/harm_bench_gcg_transfer_scenario.py +12 -1
  121. helm/benchmark/scenarios/harm_bench_scenario.py +12 -1
  122. helm/benchmark/scenarios/headqa_scenario.py +22 -0
  123. helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +13 -0
  124. helm/benchmark/scenarios/ice_scenario.py +21 -1
  125. helm/benchmark/scenarios/ifeval_scenario.py +18 -0
  126. helm/benchmark/scenarios/imdb_scenario.py +15 -0
  127. helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +21 -0
  128. helm/benchmark/scenarios/infinite_bench_en_sum_scenario.py +19 -0
  129. helm/benchmark/scenarios/koala_scenario.py +21 -1
  130. helm/benchmark/scenarios/kpi_edgar_scenario.py +21 -0
  131. helm/benchmark/scenarios/legal_contract_summarization_scenario.py +20 -0
  132. helm/benchmark/scenarios/legal_summarization_scenario.py +50 -0
  133. helm/benchmark/scenarios/legal_support_scenario.py +13 -0
  134. helm/benchmark/scenarios/legalbench_scenario.py +19 -0
  135. helm/benchmark/scenarios/lex_glue_scenario.py +11 -0
  136. helm/benchmark/scenarios/lextreme_scenario.py +11 -0
  137. helm/benchmark/scenarios/lsat_qa_scenario.py +14 -0
  138. helm/benchmark/scenarios/madinah_qa_scenario.py +73 -0
  139. helm/benchmark/scenarios/math_scenario.py +33 -0
  140. helm/benchmark/scenarios/mbzuai_human_translated_arabic_mmlu.py +68 -0
  141. helm/benchmark/scenarios/med_dialog_scenario.py +32 -1
  142. helm/benchmark/scenarios/med_mcqa_scenario.py +14 -0
  143. helm/benchmark/scenarios/med_qa_scenario.py +20 -0
  144. helm/benchmark/scenarios/medalign_scenario.py +23 -0
  145. helm/benchmark/scenarios/medbullets_scenario.py +22 -0
  146. helm/benchmark/scenarios/medcalc_bench_scenario.py +22 -0
  147. helm/benchmark/scenarios/medec_scenario.py +23 -0
  148. helm/benchmark/scenarios/medhallu_scenario.py +23 -0
  149. helm/benchmark/scenarios/medhelm/__init__.py +0 -0
  150. helm/benchmark/scenarios/medhelm/judges.yaml +14 -0
  151. helm/benchmark/scenarios/medhelm_configurable_scenario.py +101 -0
  152. helm/benchmark/scenarios/medi_qa_scenario.py +24 -1
  153. helm/benchmark/scenarios/medication_qa_scenario.py +31 -1
  154. helm/benchmark/scenarios/mental_health_scenario.py +23 -0
  155. helm/benchmark/scenarios/mimic_bhc_scenario.py +24 -0
  156. helm/benchmark/scenarios/mimic_rrs_scenario.py +23 -0
  157. helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +22 -0
  158. helm/benchmark/scenarios/mmlu_pro_scenario.py +18 -0
  159. helm/benchmark/scenarios/mmlu_scenario.py +21 -0
  160. helm/benchmark/scenarios/msmarco_scenario.py +30 -0
  161. helm/benchmark/scenarios/mtsamples_procedures_scenario.py +22 -0
  162. helm/benchmark/scenarios/mtsamples_replicate_scenario.py +22 -0
  163. helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +20 -0
  164. helm/benchmark/scenarios/narrativeqa_scenario.py +19 -0
  165. helm/benchmark/scenarios/natural_qa_scenario.py +32 -0
  166. helm/benchmark/scenarios/omni_math_scenario.py +18 -0
  167. helm/benchmark/scenarios/open_assistant_scenario.py +22 -0
  168. helm/benchmark/scenarios/openai_mrcr_scenario.py +15 -0
  169. helm/benchmark/scenarios/pubmed_qa_scenario.py +22 -0
  170. helm/benchmark/scenarios/quac_scenario.py +14 -0
  171. helm/benchmark/scenarios/race_based_med_scenario.py +23 -0
  172. helm/benchmark/scenarios/raft_scenario.py +15 -0
  173. helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +14 -1
  174. helm/benchmark/scenarios/ruler_qa_scenarios.py +40 -0
  175. helm/benchmark/scenarios/scenario.py +31 -0
  176. helm/benchmark/scenarios/seahelm_scenario.py +348 -0
  177. helm/benchmark/scenarios/self_instruct_scenario.py +29 -1
  178. helm/benchmark/scenarios/shc_bmt_scenario.py +22 -0
  179. helm/benchmark/scenarios/shc_cdi_scenario.py +20 -0
  180. helm/benchmark/scenarios/shc_conf_scenario.py +23 -0
  181. helm/benchmark/scenarios/shc_ent_scenario.py +21 -0
  182. helm/benchmark/scenarios/shc_gip_scenario.py +20 -0
  183. helm/benchmark/scenarios/shc_privacy_scenario.py +22 -0
  184. helm/benchmark/scenarios/shc_proxy_scenario.py +22 -0
  185. helm/benchmark/scenarios/shc_ptbm_scenario.py +23 -0
  186. helm/benchmark/scenarios/shc_sequoia_scenario.py +21 -0
  187. helm/benchmark/scenarios/simple_safety_tests_scenario.py +12 -1
  188. helm/benchmark/scenarios/situation_prompts.yaml +49 -0
  189. helm/benchmark/scenarios/spider_scenario.py +18 -0
  190. helm/benchmark/scenarios/starr_patient_instructions_scenario.py +22 -0
  191. helm/benchmark/scenarios/summarization_scenario.py +37 -0
  192. helm/benchmark/scenarios/synthetic_efficiency_scenario.py +22 -1
  193. helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +13 -0
  194. helm/benchmark/scenarios/test_alrage_scenario.py +23 -0
  195. helm/benchmark/scenarios/test_arabic_exams_scenario.py +21 -0
  196. helm/benchmark/scenarios/test_aratrust_scenario.py +1 -1
  197. helm/benchmark/scenarios/test_bluex_scenario.py +2 -2
  198. helm/benchmark/scenarios/thai_exam_scenario.py +95 -0
  199. helm/benchmark/scenarios/the_pile_scenario.py +13 -1
  200. helm/benchmark/scenarios/truthful_qa_scenario.py +14 -0
  201. helm/benchmark/scenarios/twitter_aae_scenario.py +20 -1
  202. helm/benchmark/scenarios/vicuna_scenario.py +21 -1
  203. helm/benchmark/scenarios/wikifact_scenario.py +20 -0
  204. helm/benchmark/scenarios/wildbench_scenario.py +18 -0
  205. helm/benchmark/scenarios/wmt_14_scenario.py +19 -0
  206. helm/benchmark/static/schema_arabic.yaml +55 -12
  207. helm/benchmark/static/schema_long_context.yaml +11 -30
  208. helm/benchmark/static/schema_medhelm.yaml +36 -0
  209. helm/benchmark/static/schema_slp.yaml +219 -0
  210. helm/benchmark/static_build/assets/audio-table-Dn5NMMeJ.png +0 -0
  211. helm/benchmark/static_build/assets/index-oIeiQW2g.css +1 -0
  212. helm/benchmark/static_build/assets/index-qOFpOyHb.js +10 -0
  213. helm/benchmark/static_build/assets/react-BteFIppM.js +85 -0
  214. helm/benchmark/static_build/assets/recharts-DxuQtTOs.js +97 -0
  215. helm/benchmark/static_build/assets/tremor-DR4fE7ko.js +10 -0
  216. helm/benchmark/static_build/index.html +5 -6
  217. helm/clients/ai21_client.py +2 -0
  218. helm/clients/aleph_alpha_client.py +2 -0
  219. helm/clients/anthropic_client.py +7 -1
  220. helm/clients/audio_language/diva_llama_client.py +2 -0
  221. helm/clients/audio_language/llama_omni/arguments.py +61 -0
  222. helm/clients/audio_language/llama_omni/constants.py +9 -0
  223. helm/clients/audio_language/llama_omni/conversation.py +213 -0
  224. helm/clients/audio_language/llama_omni/model/__init__.py +0 -0
  225. helm/clients/audio_language/llama_omni/model/builder.py +88 -0
  226. helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py +190 -0
  227. helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py +118 -0
  228. helm/clients/audio_language/llama_omni/model/omni_speech_arch.py +249 -0
  229. helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py +9 -0
  230. helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py +27 -0
  231. helm/clients/audio_language/llama_omni/model/speech_generator/builder.py +9 -0
  232. helm/clients/audio_language/llama_omni/model/speech_generator/generation.py +622 -0
  233. helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py +104 -0
  234. helm/clients/audio_language/llama_omni/model/speech_projector/builder.py +9 -0
  235. helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py +27 -0
  236. helm/clients/audio_language/llama_omni/preprocess.py +295 -0
  237. helm/clients/audio_language/llama_omni/utils.py +202 -0
  238. helm/clients/audio_language/llama_omni_client.py +2 -1
  239. helm/clients/audio_language/qwen2_5_omni_client.py +2 -1
  240. helm/clients/audio_language/qwen2_audiolm_client.py +2 -1
  241. helm/clients/audio_language/qwen_audiolm_client.py +2 -1
  242. helm/clients/audio_language/qwen_omni/configuration_qwen2_5_omni.py +519 -0
  243. helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py +4308 -0
  244. helm/clients/audio_language/qwen_omni/processing_qwen2_5_omni.py +270 -0
  245. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/__init__.py +0 -0
  246. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/__init__.py +8 -0
  247. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/audio_process.py +56 -0
  248. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py +380 -0
  249. helm/clients/bedrock_client.py +2 -0
  250. helm/clients/cohere_client.py +3 -0
  251. helm/clients/google_client.py +2 -0
  252. helm/clients/http_model_client.py +2 -0
  253. helm/clients/huggingface_client.py +2 -1
  254. helm/clients/ibm_client.py +3 -1
  255. helm/clients/image_generation/adobe_vision_client.py +2 -0
  256. helm/clients/image_generation/aleph_alpha_image_generation_client.py +2 -0
  257. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +1 -1
  258. helm/clients/image_generation/cogview2_client.py +2 -1
  259. helm/clients/image_generation/dalle2_client.py +2 -0
  260. helm/clients/image_generation/dalle_mini_client.py +2 -1
  261. helm/clients/image_generation/deep_floyd_client.py +2 -0
  262. helm/clients/image_generation/huggingface_diffusers_client.py +2 -1
  263. helm/clients/image_generation/lexica_client.py +2 -0
  264. helm/clients/image_generation/mindalle/models/stage1/layers.py +2 -2
  265. helm/clients/image_generation/mindalle_client.py +2 -1
  266. helm/clients/image_generation/together_image_generation_client.py +2 -0
  267. helm/clients/megatron_client.py +2 -0
  268. helm/clients/mistral_client.py +2 -0
  269. helm/clients/moderation_api_client.py +2 -0
  270. helm/clients/openai_client.py +36 -20
  271. helm/clients/openai_responses_client.py +27 -3
  272. helm/clients/openrouter_client.py +31 -0
  273. helm/clients/palmyra_client.py +2 -1
  274. helm/clients/reka_client.py +2 -1
  275. helm/clients/stanfordhealthcare_azure_openai_client.py +2 -2
  276. helm/clients/stanfordhealthcare_http_model_client.py +2 -0
  277. helm/clients/test_openrouter_client.py +69 -0
  278. helm/clients/together_client.py +52 -11
  279. helm/clients/vertexai_client.py +12 -2
  280. helm/clients/vision_language/huggingface_vision2seq_client.py +2 -1
  281. helm/clients/vision_language/huggingface_vlm_client.py +2 -0
  282. helm/clients/vision_language/idefics_client.py +2 -1
  283. helm/clients/vision_language/open_flamingo_client.py +2 -1
  284. helm/clients/vision_language/paligemma_client.py +2 -1
  285. helm/clients/vision_language/palmyra_vision_client.py +2 -0
  286. helm/clients/vision_language/qwen2_vlm_client.py +2 -1
  287. helm/clients/vision_language/qwen_vlm_client.py +2 -1
  288. helm/clients/writer_client.py +2 -0
  289. helm/common/hierarchical_logger.py +20 -0
  290. helm/common/optional_dependencies.py +1 -1
  291. helm/common/test_general.py +4 -0
  292. helm/config/model_deployments.yaml +300 -1
  293. helm/config/model_metadata.yaml +302 -9
  294. helm/config/tokenizer_configs.yaml +92 -4
  295. helm/proxy/example_queries.py +8 -8
  296. helm/proxy/server.py +2 -1
  297. helm/proxy/static/index.css +4 -0
  298. helm/proxy/static/index.js +7 -1
  299. helm/benchmark/metrics/aci_bench_metrics.py +0 -14
  300. helm/benchmark/metrics/chw_care_plan_metrics.py +0 -14
  301. helm/benchmark/metrics/dischargeme_metrics.py +0 -14
  302. helm/benchmark/metrics/med_dialog_metrics.py +0 -14
  303. helm/benchmark/metrics/medalign_metrics.py +0 -14
  304. helm/benchmark/metrics/medi_qa_metrics.py +0 -14
  305. helm/benchmark/metrics/medication_qa_metrics.py +0 -14
  306. helm/benchmark/metrics/mental_health_metrics.py +0 -14
  307. helm/benchmark/metrics/mimic_bhc_metrics.py +0 -14
  308. helm/benchmark/metrics/mimic_rrs_metrics.py +0 -14
  309. helm/benchmark/metrics/mtsamples_procedures_metrics.py +0 -14
  310. helm/benchmark/metrics/mtsamples_replicate_metrics.py +0 -14
  311. helm/benchmark/metrics/starr_patient_instructions_metrics.py +0 -14
  312. helm/benchmark/static_build/assets/index-b9779128.css +0 -1
  313. helm/benchmark/static_build/assets/index-e439d5e1.js +0 -10
  314. helm/benchmark/static_build/assets/react-f82877fd.js +0 -85
  315. helm/benchmark/static_build/assets/recharts-4037aff0.js +0 -97
  316. helm/benchmark/static_build/assets/tremor-38a10867.js +0 -10
  317. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/WHEEL +0 -0
  318. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/entry_points.txt +0 -0
  319. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/licenses/LICENSE +0 -0
  320. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/top_level.txt +0 -0
  321. /helm/benchmark/static_build/assets/{air-overview-d2e6c49f.png → air-overview-DpBbyagA.png} +0 -0
  322. /helm/benchmark/static_build/assets/{crfm-logo-74391ab8.png → crfm-logo-Du4T1uWZ.png} +0 -0
  323. /helm/benchmark/static_build/assets/{heim-logo-3e5e3aa4.png → heim-logo-BJtQlEbV.png} +0 -0
  324. /helm/benchmark/static_build/assets/{helm-logo-simple-2ed5400b.png → helm-logo-simple-DzOhNN41.png} +0 -0
  325. /helm/benchmark/static_build/assets/{helm-safety-2907a7b6.png → helm-safety-COfndXuS.png} +0 -0
  326. /helm/benchmark/static_build/assets/{helmhero-28e90f4d.png → helmhero-D9TvmJsp.png} +0 -0
  327. /helm/benchmark/static_build/assets/{medhelm-overview-eac29843.png → medhelm-overview-CND0EIsy.png} +0 -0
  328. /helm/benchmark/static_build/assets/{medhelm-v1-overview-3ddfcd65.png → medhelm-v1-overview-Cu2tphBB.png} +0 -0
  329. /helm/benchmark/static_build/assets/{overview-74aea3d8.png → overview-BwypNWnk.png} +0 -0
  330. /helm/benchmark/static_build/assets/{process-flow-bd2eba96.png → process-flow-DWDJC733.png} +0 -0
  331. /helm/benchmark/static_build/assets/{vhelm-aspects-1437d673.png → vhelm-aspects-NiDQofvP.png} +0 -0
  332. /helm/benchmark/static_build/assets/{vhelm-framework-a1ca3f3f.png → vhelm-framework-NxJE4fdA.png} +0 -0
  333. /helm/benchmark/static_build/assets/{vhelm-model-8afb7616.png → vhelm-model-ypCL5Yvq.png} +0 -0
helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py
@@ -0,0 +1,4308 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_qwen2_5_omni.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+ import math
24
+ import operator
25
+ from dataclasses import dataclass
26
+ from itertools import accumulate
27
+ from typing import Any, Dict, List, Optional, Tuple, Union, cast
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+ from torch.nn import ConvTranspose1d, Parameter
34
+ from helm.common.optional_dependencies import handle_module_not_found_error
35
+
36
+ from transformers.activations import ACT2FN
37
+ from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, SlidingWindowCache, StaticCache
38
+ from transformers.generation import GenerationMixin
39
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
40
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
41
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
42
+
43
+ try:
44
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
45
+ except ModuleNotFoundError as e:
46
+ handle_module_not_found_error(e, ["audiolm"])
47
+ from transformers.utils import (
48
+ add_start_docstrings,
49
+ is_flash_attn_2_available,
50
+ is_flash_attn_greater_or_equal_2_10,
51
+ logging,
52
+ )
53
+ from transformers.utils.hub import cached_file
54
+ from helm.clients.audio_language.qwen_omni.configuration_qwen2_5_omni import (
55
+ Qwen2_5OmniAudioEncoderConfig,
56
+ Qwen2_5OmniBigVGANConfig,
57
+ Qwen2_5OmniConfig,
58
+ Qwen2_5OmniDiTConfig,
59
+ Qwen2_5OmniTalkerConfig,
60
+ Qwen2_5OmniTextConfig,
61
+ Qwen2_5OmniThinkerConfig,
62
+ Qwen2_5OmniToken2WavConfig,
63
+ Qwen2_5OmniVisionEncoderConfig,
64
+ )
65
+
66
+
67
+ if is_flash_attn_2_available():
68
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
69
+ from flash_attn.layers.rotary import apply_rotary_emb
70
+ else:
71
+ flash_attn_varlen_func = None
72
+ apply_rotary_emb = None
73
+
74
+
75
+ if is_flash_attn_2_available():
76
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
77
+ else:
78
+ flash_attn_varlen_func = None
79
+
80
+
81
+ logger = logging.get_logger(__name__)
82
+
83
+
84
+ # @add_start_docstrings(
85
+ # "The bare Qwen2.5Omni Model outputting raw hidden-states without any specific head on top.",
86
+ # QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniConfig"),
87
+ # )
88
+ class Qwen2_5OmniPreTrainedModel(PreTrainedModel):
89
+ config_class: Any = Qwen2_5OmniConfig
90
+ base_model_prefix = "model"
91
+ supports_gradient_checkpointing = True
92
+ _skip_keys_device_placement = "past_key_values"
93
+ _supports_flash_attn_2 = True
94
+ _supports_sdpa = True
95
+ _supports_cache_class = True
96
+ _supports_static_cache = True
97
+
98
+ def _init_weights(self, module):
99
+ # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only
100
+ # inference and fine-tuning - so the proper init weights code has been removed
101
+ std = self.config.init_std if hasattr(self.config, "init_std") else 0.02
102
+
103
+ if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d)):
104
+ module.weight.data.normal_(mean=0.0, std=std)
105
+ if module.bias is not None:
106
+ module.bias.data.zero_()
107
+ elif isinstance(module, nn.Embedding):
108
+ module.weight.data.normal_(mean=0.0, std=std)
109
+ if module.padding_idx is not None:
110
+ module.weight.data[module.padding_idx].zero_()
111
+
112
+
113
+ class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel):
114
+ def _prepare_4d_causal_attention_mask_with_cache_position(
115
+ self,
116
+ attention_mask: torch.Tensor,
117
+ sequence_length: int,
118
+ target_length: int,
119
+ dtype: torch.dtype,
120
+ device: torch.device,
121
+ min_dtype: float,
122
+ cache_position: torch.Tensor,
123
+ batch_size: int,
124
+ ):
125
+ if attention_mask is not None and attention_mask.dim() == 4:
126
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
127
+ causal_mask = attention_mask
128
+ else:
129
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
130
+ if sequence_length != 1:
131
+ causal_mask = torch.triu(causal_mask, diagonal=1)
132
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
133
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
134
+ if attention_mask is not None:
135
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
136
+ mask_length = attention_mask.shape[-1]
137
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
138
+ padding_mask = padding_mask == 0
139
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
140
+ padding_mask, min_dtype
141
+ )
142
+
143
+ return causal_mask
144
+
145
+ def get_input_embeddings(self):
146
+ return self.model.get_input_embeddings()
147
+
148
+ def set_input_embeddings(self, value):
149
+ self.model.set_input_embeddings(value)
150
+
151
+ def get_llm_pos_ids_for_vision(
152
+ self,
153
+ start_idx: int,
154
+ vision_idx: int,
155
+ spatial_merge_size: int,
156
+ t_index: torch.Tensor,
157
+ grid_hs: torch.Tensor,
158
+ grid_ws: torch.Tensor,
159
+ ):
160
+ llm_pos_ids_list = []
161
+ llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
162
+ llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
163
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten()
164
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten()
165
+ t_index_p = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long()
166
+ _llm_pos_ids = torch.stack([t_index_p, h_index, w_index])
167
+ llm_pos_ids_list.append(_llm_pos_ids + start_idx) # + 1 ) # 12.09 by malinhan
168
+ llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
169
+ return llm_pos_ids
170
+
171
+ def get_chunked_index(self, llm_pos_ids, t_ntoken_per_chunk, st_idx):
172
+ def _iter():
173
+ i, start_idx = 0, 0 # skip bos token
174
+ current_chunk = 1
175
+ while i < llm_pos_ids.shape[1]: # skip eos token
176
+ if llm_pos_ids[0][i] - st_idx >= current_chunk * t_ntoken_per_chunk:
177
+ yield (start_idx, i)
178
+ start_idx = i
179
+ current_chunk += 1
180
+ i += 1
181
+ yield (start_idx, llm_pos_ids.shape[1])
182
+
183
+ return list(_iter())
184
+
185
+ def get_rope_index(
186
+ self,
187
+ input_ids: Optional[torch.LongTensor] = None,
188
+ image_grid_thw: Optional[torch.LongTensor] = None,
189
+ video_grid_thw: Optional[torch.LongTensor] = None,
190
+ attention_mask: Optional[torch.Tensor] = None,
191
+ use_audio_in_video: Optional[bool] = False,
192
+ audio_seqlens: Optional[torch.Tensor] = None,
193
+ second_per_grids: Optional[torch.Tensor] = None,
194
+ ):
195
+ spatial_merge_size = self.spatial_merge_size
196
+ image_token_id = self.config.image_token_index
197
+ video_token_id = self.config.video_token_index
198
+ audio_token_id = self.config.audio_token_index
199
+ vision_start_token_id = self.config.vision_start_token_id
200
+ audio_start_token_id = self.config.audio_start_token_id
201
+ position_id_per_seconds = self.config.position_id_per_seconds
202
+ seconds_per_chunk = self.config.seconds_per_chunk
203
+
204
+ mrope_position_deltas = []
205
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
206
+ total_input_ids = input_ids
207
+ if attention_mask is None:
208
+ attention_mask = torch.ones_like(total_input_ids)
209
+ position_ids = torch.ones(
210
+ 3,
211
+ input_ids.shape[0],
212
+ input_ids.shape[1],
213
+ dtype=input_ids.dtype,
214
+ device=input_ids.device,
215
+ )
216
+ image_idx, video_idx, audio_idx = 0, 0, 0
217
+ attention_mask = attention_mask.to(total_input_ids.device)
218
+ for i, input_ids_p in enumerate(total_input_ids):
219
+ input_ids_p = input_ids_p[attention_mask[i] == 1]
220
+ image_nums, video_nums, audio_nums = 0, 0, 0
221
+ vision_start_indices = torch.argwhere(input_ids_p == vision_start_token_id).squeeze(1)
222
+ vision_tokens = input_ids_p[vision_start_indices + 1]
223
+ audio_nums = int(torch.sum(input_ids_p == audio_start_token_id).item())
224
+ image_nums = (vision_tokens == image_token_id).sum()
225
+ video_nums = (
226
+ (vision_tokens == audio_start_token_id).sum()
227
+ if use_audio_in_video
228
+ else (vision_tokens == video_token_id).sum()
229
+ )
230
+ input_tokens = input_ids_p.tolist()
231
+ llm_pos_ids_list: list = []
232
+ st = 0
233
+ remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums
234
+ multimodal_nums = (
235
+ image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums
236
+ )
237
+ for _ in range(multimodal_nums):
238
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
239
+ if image_token_id in input_tokens and remain_images > 0:
240
+ ed_image = input_tokens.index(image_token_id, st)
241
+ else:
242
+ ed_image = len(input_tokens) + 1
243
+ if video_token_id in input_tokens and remain_videos > 0:
244
+ ed_video = input_tokens.index(video_token_id, st)
245
+ else:
246
+ ed_video = len(input_tokens) + 1
247
+ if audio_token_id in input_tokens and remain_audios > 0:
248
+ ed_audio = input_tokens.index(audio_token_id, st)
249
+ else:
250
+ ed_audio = len(input_tokens) + 1
251
+ min_ed = min(ed_image, ed_video, ed_audio)
252
+ if min_ed == ed_audio and audio_seqlens is not None:
253
+ text_len = min_ed - st - 1
254
+ if text_len != 0:
255
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
256
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
257
+
258
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
259
+ bos_len = 1
260
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
261
+
262
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
263
+ audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1
264
+ llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
265
+ llm_pos_ids_list.append(llm_pos_ids)
266
+
267
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
268
+ eos_len = 1
269
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
270
+
271
+ st += text_len + bos_len + audio_len + eos_len
272
+ audio_idx += 1
273
+ remain_audios -= 1
274
+
275
+ elif min_ed == ed_image and image_grid_thw is not None:
276
+ text_len = min_ed - st - 1
277
+ if text_len != 0:
278
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
279
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
280
+
281
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
282
+ bos_len = 1
283
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
284
+
285
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
286
+ grid_t = image_grid_thw[image_idx][0]
287
+ grid_hs = image_grid_thw[:, 1]
288
+ grid_ws = image_grid_thw[:, 2]
289
+ t_index = (torch.arange(grid_t.item()) * 1 * position_id_per_seconds).long()
290
+ llm_pos_ids = self.get_llm_pos_ids_for_vision(
291
+ st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
292
+ )
293
+ image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
294
+ llm_pos_ids_list.append(llm_pos_ids)
295
+
296
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
297
+ eos_len = 1
298
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
299
+
300
+ st += text_len + bos_len + image_len + eos_len
301
+ image_idx += 1
302
+ remain_images -= 1
303
+
304
+ elif (
305
+ min_ed == ed_video
306
+ and not use_audio_in_video
307
+ and video_grid_thw is not None
308
+ and second_per_grids is not None
309
+ ):
310
+ text_len = min_ed - st - 1
311
+ if text_len != 0:
312
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
313
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
314
+
315
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
316
+ bos_len = 1
317
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
318
+
319
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
320
+ grid_t = video_grid_thw[video_idx][0]
321
+ grid_hs = video_grid_thw[:, 1]
322
+ grid_ws = video_grid_thw[:, 2]
323
+ t_index = (
324
+ torch.arange(grid_t.item())
325
+ * second_per_grids[video_idx].cpu().float()
326
+ * position_id_per_seconds
327
+ ).long()
328
+ llm_pos_ids = self.get_llm_pos_ids_for_vision(
329
+ st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
330
+ )
331
+ video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
332
+ llm_pos_ids_list.append(llm_pos_ids)
333
+
334
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
335
+ eos_len = 1
336
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
337
+
338
+ st += text_len + bos_len + video_len + eos_len
339
+ video_idx += 1
340
+ remain_videos -= 1
341
+
342
+ elif (
343
+ min_ed == ed_video
344
+ and use_audio_in_video
345
+ and audio_seqlens is not None
346
+ and video_grid_thw is not None
347
+ and second_per_grids is not None
348
+ ):
349
+ text_len = min_ed - st - 2
350
+ if text_len != 0:
351
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
352
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
353
+
354
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
355
+ bos_len = 1
356
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
357
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
358
+
359
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
360
+ audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1
361
+ audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
362
+ grid_t = video_grid_thw[video_idx][0]
363
+ grid_hs = video_grid_thw[:, 1]
364
+ grid_ws = video_grid_thw[:, 2]
365
+
366
+ t_index = (
367
+ torch.arange(grid_t.item())
368
+ * second_per_grids[video_idx].cpu().float()
369
+ * position_id_per_seconds
370
+ ).long()
371
+ video_llm_pos_ids = self.get_llm_pos_ids_for_vision(
372
+ st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
373
+ )
374
+
375
+ t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
376
+ video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids, t_ntoken_per_chunk, st_idx)
377
+ audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids, t_ntoken_per_chunk, st_idx)
378
+ sub_len = 0
379
+ for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
380
+ video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None
381
+ audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None
382
+ if video_chunk_index is not None:
383
+ sub_len += video_chunk_index[1] - video_chunk_index[0]
384
+
385
+ llm_pos_ids_list.append(
386
+ video_llm_pos_ids[:, video_chunk_index[0] : video_chunk_index[1]]
387
+ )
388
+ if audio_chunk_index is not None:
389
+ sub_len += audio_chunk_index[1] - audio_chunk_index[0]
390
+
391
+ llm_pos_ids_list.append(
392
+ audio_llm_pos_ids[:, audio_chunk_index[0] : audio_chunk_index[1]]
393
+ )
394
+ video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
395
+
396
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
397
+ eos_len = 1
398
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
399
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
400
+
401
+ st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2
402
+
403
+ audio_idx += 1
404
+ video_idx += 1
405
+ remain_videos -= 1
406
+ remain_audios -= 1
407
+
408
+ if st < len(input_tokens):
409
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
410
+ text_len = len(input_tokens) - st
411
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
412
+
413
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
414
+
415
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
416
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids_p))
417
+ mrope_position_deltas_p = torch.tensor(mrope_position_deltas, device=input_ids_p.device).unsqueeze(1)
418
+
419
+ return position_ids, mrope_position_deltas_p
420
+ else:
421
+ if attention_mask is not None:
422
+ position_ids = attention_mask.long().cumsum(-1) - 1
423
+ position_ids.masked_fill_(attention_mask == 0, 1)
424
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
425
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
426
+ mrope_position_deltas_p = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)
427
+
428
+ return position_ids, mrope_position_deltas_p
429
+
430
+
431
+ @dataclass
432
+ class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput):
433
+
434
+ loss: Optional[torch.FloatTensor] = None
435
+ logits: Optional[torch.FloatTensor] = None
436
+ past_key_values: Optional[List[torch.FloatTensor]] = None
437
+ hidden_states: Optional[Any] = None
438
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
439
+ attention_mask: Optional[torch.Tensor] = None
440
+ rope_deltas: Optional[torch.Tensor] = None
441
+
442
+
443
+ class Qwen2_5OmniAudioAttention(nn.Module):
444
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
445
+
446
+ def __init__(
447
+ self,
448
+ embed_dim: int,
449
+ num_heads: int,
450
+ dropout: float = 0.0,
451
+ is_decoder: bool = False,
452
+ bias: bool = True,
453
+ is_causal: bool = False,
454
+ layer_idx: Optional[int] = None,
455
+ config: Optional[Qwen2_5OmniThinkerConfig] = None,
456
+ ):
457
+ super().__init__()
458
+ self.embed_dim = embed_dim
459
+ self.num_heads = num_heads
460
+ self.dropout = dropout
461
+ self.head_dim = embed_dim // num_heads
462
+ self.config = config
463
+
464
+ if (self.head_dim * num_heads) != self.embed_dim:
465
+ raise ValueError(
466
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
467
+ f" and `num_heads`: {num_heads})."
468
+ )
469
+ self.scaling = self.head_dim**-0.5
470
+ self.is_decoder = is_decoder
471
+ self.is_causal = is_causal
472
+
473
+ if layer_idx is None and is_decoder:
474
+ logger.warning_once(
475
+ f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
476
+ "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
477
+ "when creating this class."
478
+ )
479
+ self.layer_idx = layer_idx
480
+
481
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
482
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
483
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
484
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
485
+
486
+ def forward(
487
+ self,
488
+ hidden_states: torch.Tensor,
489
+ key_value_states: Optional[torch.Tensor] = None,
490
+ past_key_value: Optional[EncoderDecoderCache] = None,
491
+ cu_seqlens: Optional[torch.Tensor] = None,
492
+ layer_head_mask: Optional[torch.Tensor] = None,
493
+ output_attentions: bool = False,
494
+ cache_position: Optional[torch.LongTensor] = None,
495
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
496
+ """Input shape: Batch x Time x Channel"""
497
+
498
+ # if key_value_states are provided this layer is used as a cross-attention layer
499
+ # for the decoder
500
+ is_cross_attention = key_value_states is not None
501
+ seq_length, _ = hidden_states.size()
502
+
503
+ # get query proj
504
+ # query_states = self.q_proj(hidden_states)
505
+ query_states = (hidden_states @ self.q_proj.weight.t()) + self.q_proj.bias
506
+
507
+ query_states = query_states.reshape(seq_length, self.num_heads, -1)
508
+
509
+ if past_key_value is not None:
510
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
511
+ if is_cross_attention:
512
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
513
+ past_key_value.is_updated[self.layer_idx] = True
514
+ past_key_value = past_key_value.cross_attention_cache
515
+ else:
516
+ past_key_value = past_key_value.self_attention_cache
517
+
518
+ # use key_value_states if cross attention
519
+ current_states = key_value_states if key_value_states is not None else hidden_states
520
+ if is_cross_attention and past_key_value and is_updated:
521
+ # reuse k,v, cross_attentions
522
+ key_states = past_key_value.key_cache[self.layer_idx]
523
+ value_states = past_key_value.value_cache[self.layer_idx]
524
+ else:
525
+ key_states = self.k_proj(current_states).reshape(seq_length, self.num_heads, -1)
526
+ value_states = self.v_proj(current_states).reshape(seq_length, self.num_heads, -1)
527
+ if past_key_value is not None:
528
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
529
+ cache_position = cache_position if not is_cross_attention else None
530
+ key_states, value_states = past_key_value.update(
531
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
532
+ )
533
+
534
+ query_states = query_states.transpose(0, 1)
535
+ key_states = key_states.transpose(0, 1)
536
+ value_states = value_states.transpose(0, 1)
537
+ attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
538
+
539
+ attention_mask = torch.full(
540
+ [1, seq_length, key_states.shape[1]],
541
+ torch.finfo(query_states.dtype).min,
542
+ device=query_states.device,
543
+ dtype=query_states.dtype,
544
+ )
545
+ assert cu_seqlens is not None
546
+ for i in range(1, cu_seqlens.size(0)):
547
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
548
+
549
+ attn_weights = attn_weights + attention_mask
550
+
551
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
552
+
553
+ if layer_head_mask is not None:
554
+ if layer_head_mask.size() != (self.num_heads,):
555
+ raise ValueError(
556
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
557
+ f" {layer_head_mask.size()}"
558
+ )
559
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
560
+
561
+ attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1).reshape(seq_length, self.embed_dim)
562
+
563
+ attn_output = self.out_proj(attn_output)
564
+
565
+ return attn_output, attn_weights, past_key_value
566
+
567
+
568
+ class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention):
569
+
570
+ def __init__(self, *args, **kwargs):
571
+ super().__init__(*args, **kwargs)
572
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
573
+
574
+ def forward(
575
+ self,
576
+ hidden_states: torch.Tensor,
577
+ key_value_states: Optional[torch.Tensor] = None,
578
+ past_key_value: Optional[EncoderDecoderCache] = None,
579
+ cu_seqlens: Optional[torch.Tensor] = None,
580
+ layer_head_mask: Optional[torch.Tensor] = None,
581
+ output_attentions: bool = False,
582
+ cache_position: Optional[torch.LongTensor] = None,
583
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
584
+ if isinstance(past_key_value, StaticCache):
585
+ raise ValueError(
586
+ "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
587
+ "Use `attn_implementation='sdpa'` in the meantime, and open an issue "
588
+ "at https://github.com/huggingface/transformers"
589
+ )
590
+ # Qwen2.5OmniThinkerFlashAttention2 attention does not support output_attentions
591
+ if output_attentions:
592
+ raise ValueError("Qwen2.5OmniThinkerFlashAttention2 attention does not support output_attentions")
593
+
594
+ # if key_value_states are provided this layer is used as a cross-attention layer
595
+ # for the decoder
596
+ is_cross_attention = key_value_states is not None
597
+ seq_length, all_dim = hidden_states.size()
598
+ query_states = (hidden_states @ self.q_proj.weight.t()) + (
599
+ self.q_proj.bias if self.q_proj.bias is not None else 0
600
+ )
601
+ query_states = query_states.reshape(seq_length, self.num_heads, -1)
602
+
603
+ if past_key_value is not None:
604
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
605
+ if is_cross_attention:
606
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
607
+ past_key_value.is_updated[self.layer_idx] = True
608
+ past_key_value = past_key_value.cross_attention_cache
609
+ else:
610
+ past_key_value = past_key_value.self_attention_cache
611
+
612
+ # use key_value_states if cross attention
613
+ current_states = key_value_states if key_value_states is not None else hidden_states
614
+ if is_cross_attention and past_key_value and is_updated:
615
+ # reuse k,v, cross_attentions
616
+ key_states = past_key_value.key_cache[self.layer_idx]
617
+ value_states = past_key_value.value_cache[self.layer_idx]
618
+ else:
619
+ key_states = (current_states @ self.k_proj.weight.t()) + (
620
+ self.k_proj.bias if self.k_proj.bias is not None else 0
621
+ )
622
+ key_states = key_states.reshape(seq_length, self.num_heads, -1)
623
+ value_states = (current_states @ self.v_proj.weight.t()) + (
624
+ self.v_proj.bias if self.v_proj.bias is not None else 0
625
+ )
626
+ value_states = value_states.reshape(seq_length, self.num_heads, -1)
627
+
628
+ if past_key_value is not None:
629
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
630
+ cache_position = cache_position if not is_cross_attention else None
631
+ key_states, value_states = past_key_value.update(
632
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
633
+ )
634
+ assert cu_seqlens is not None
635
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
636
+ attn_output = flash_attn_varlen_func(
637
+ query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
638
+ )
639
+ attn_output = attn_output.reshape(seq_length, all_dim)
640
+ attn_output = (attn_output @ self.out_proj.weight.t()) + (
641
+ self.out_proj.bias if self.out_proj.bias is not None else 0
642
+ )
643
+
644
+ if not output_attentions:
645
+ attn_weights = None
646
+
647
+ return attn_output, attn_weights, past_key_value
648
+
649
+
650
+ class Qwen2_5OmniAudioSdpaAttention(Qwen2_5OmniAudioAttention):
651
+ def forward(
652
+ self,
653
+ hidden_states: torch.Tensor,
654
+ key_value_states: Optional[torch.Tensor] = None,
655
+ past_key_value: Optional[EncoderDecoderCache] = None,
656
+ cu_seqlens: Optional[torch.Tensor] = None,
657
+ layer_head_mask: Optional[torch.Tensor] = None,
658
+ output_attentions: bool = False,
659
+ cache_position: Optional[torch.LongTensor] = None,
660
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
661
+ """Input shape: Batch x Time x Channel"""
662
+ if output_attentions or layer_head_mask is not None:
663
+ logger.warning_once(
664
+ "Qwen2_5OmniThinkerModel is using Qwen2_5OmniThinkerSdpaAttention, but "
665
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` "
666
+ "or `layer_head_mask` not None. Falling back to the manual attention"
667
+ ' implementation, but specifying the manual implementation will be required "'
668
+ '"from Transformers version v5.0.0 onwards. This warning can be removed using the argument"'
669
+ '" `attn_implementation="eager"` when loading the model.'
670
+ )
671
+ return super().forward(
672
+ hidden_states,
673
+ key_value_states=key_value_states,
674
+ past_key_value=past_key_value,
675
+ cu_seqlens=cu_seqlens,
676
+ layer_head_mask=layer_head_mask,
677
+ output_attentions=output_attentions,
678
+ cache_position=cache_position,
679
+ )
680
+
681
+ # if key_value_states are provided this layer is used as a cross-attention layer
682
+ # for the decoder
683
+ is_cross_attention = key_value_states is not None
684
+ seq_length, _ = hidden_states.size()
685
+
686
+ # get query proj
687
+ query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
688
+
689
+ if past_key_value is not None:
690
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
691
+ if is_cross_attention:
692
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
693
+ past_key_value.is_updated[self.layer_idx] = True
694
+ past_key_value = past_key_value.cross_attention_cache
695
+ else:
696
+ past_key_value = past_key_value.self_attention_cache
697
+
698
+ # use key_value_states if cross attention
699
+ current_states = key_value_states if key_value_states is not None else hidden_states
700
+ if is_cross_attention and past_key_value and is_updated:
701
+ # reuse k,v, cross_attentions
702
+ key_states = past_key_value.key_cache[self.layer_idx]
703
+ value_states = past_key_value.value_cache[self.layer_idx]
704
+ else:
705
+ key_states = self.k_proj(current_states).reshape(seq_length, self.num_heads, -1)
706
+ value_states = self.v_proj(current_states).reshape(seq_length, self.num_heads, -1)
707
+ if past_key_value is not None:
708
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
709
+ cache_position = cache_position if not is_cross_attention else None
710
+ key_states, value_states = past_key_value.update(
711
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
712
+ )
713
+
714
+ attention_mask = torch.zeros([1, seq_length, key_states.shape[0]], device=query_states.device, dtype=torch.bool)
715
+ assert cu_seqlens is not None
716
+ for i in range(1, cu_seqlens.size(0)):
717
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
718
+
719
+ query_states = query_states.transpose(0, 1)
720
+ key_states = key_states.transpose(0, 1)
721
+ value_states = value_states.transpose(0, 1)
722
+
723
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
724
+ query_states,
725
+ key_states,
726
+ value_states,
727
+ attn_mask=attention_mask,
728
+ dropout_p=self.dropout if self.training else 0.0,
729
+ )
730
+ attn_output = attn_output.transpose(0, 1)
731
+
732
+ attn_output = attn_output.reshape(seq_length, self.embed_dim)
733
+ attn_output = self.out_proj(attn_output)
734
+ return attn_output, None, past_key_value
735
+
736
+
737
+ QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES = {
738
+ "eager": Qwen2_5OmniAudioAttention,
739
+ "flash_attention_2": Qwen2_5OmniAudioFlashAttention2,
740
+ "sdpa": Qwen2_5OmniAudioSdpaAttention,
741
+ }
742
+
743
+
744
+ class Qwen2_5OmniAudioEncoderLayer(nn.Module):
745
+ def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
746
+ super().__init__()
747
+ self.embed_dim = config.d_model
748
+ self.self_attn = QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES[config._attn_implementation](
749
+ embed_dim=self.embed_dim,
750
+ num_heads=config.encoder_attention_heads,
751
+ dropout=config.attention_dropout,
752
+ config=config,
753
+ )
754
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
755
+ self.dropout = config.dropout
756
+ self.activation_fn = ACT2FN[config.activation_function]
757
+ self.activation_dropout = config.activation_dropout
758
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
759
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
760
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
761
+
762
+ def forward(
763
+ self,
764
+ hidden_states: torch.Tensor,
765
+ cu_seqlens: torch.Tensor,
766
+ layer_head_mask: torch.Tensor,
767
+ output_attentions: bool = False,
768
+ ):
769
+ residual = hidden_states
770
+ hidden_states = self.self_attn_layer_norm(hidden_states)
771
+ hidden_states, attn_weights, _ = self.self_attn(
772
+ hidden_states=hidden_states,
773
+ cu_seqlens=cu_seqlens,
774
+ layer_head_mask=layer_head_mask,
775
+ output_attentions=output_attentions,
776
+ )
777
+ hidden_states = residual + hidden_states
778
+ residual = hidden_states
779
+ hidden_states = self.final_layer_norm(hidden_states)
780
+ hidden_states = (hidden_states @ self.fc1.weight.t()) + (self.fc1.bias if self.fc1.bias is not None else 0)
781
+ hidden_states = self.activation_fn(hidden_states)
782
+ hidden_states = (hidden_states @ self.fc2.weight.t()) + (self.fc2.bias if self.fc2.bias is not None else 0)
783
+ hidden_states = residual + hidden_states
784
+
785
+ if hidden_states.dtype == torch.float16 and (
786
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
787
+ ):
788
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
789
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
790
+
791
+ outputs: Tuple[Any, ...]
792
+ outputs = (hidden_states,)
793
+
794
+ if output_attentions and attn_weights is not None:
795
+ outputs += (attn_weights,)
796
+
797
+ return outputs
798
+
799
+
800
+ class SinusoidsPositionEmbedding(nn.Module):
801
+ def __init__(self, length, channels, max_timescale=10000):
802
+ super().__init__()
803
+ assert channels % 2 == 0
804
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
805
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
806
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
807
+ self.register_buffer(
808
+ "positional_embedding",
809
+ torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
810
+ persistent=False,
811
+ )
812
+
813
+ def forward(self, seqlen: int):
814
+ return self.positional_embedding[:seqlen, :]
815
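A small usage sketch, assuming the `SinusoidsPositionEmbedding` class above is in scope; sizes are hypothetical. The table concatenates the sine half and the cosine half along the channel dimension, so position 0 is all zeros followed by all ones:

    import torch

    pe = SinusoidsPositionEmbedding(length=8, channels=6)
    emb = pe(seqlen=4)
    print(emb.shape)                                   # torch.Size([4, 6])
    print(torch.allclose(emb[0, :3], torch.zeros(3)))  # True: sin(0) == 0
    print(torch.allclose(emb[0, 3:], torch.ones(3)))   # True: cos(0) == 1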
+
816
+
817
+ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
818
+ """
819
+ Transformer encoder consisting of *config.encoder_layers* self
820
+ attention layers. Each layer is a [`Qwen2_5OmniAudioEncoderLayer`].
821
+
822
+ Args:
823
+ config: Qwen2_5OmniAudioEncoderConfig
824
+ """
825
+
826
+ config_class = Qwen2_5OmniAudioEncoderConfig
827
+ main_input_name = "input_features"
828
+ _no_split_modules = ["Qwen2_5OmniAudioEncoderLayer"]
829
+ _supports_sdpa = True
830
+
831
+ def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
832
+ super().__init__(config)
833
+ self.dropout = config.dropout
834
+ self.layerdrop = config.encoder_layerdrop
835
+
836
+ embed_dim = config.d_model
837
+ self.num_mel_bins = config.num_mel_bins
838
+ self.max_source_positions = config.max_source_positions
839
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
840
+ self.n_window = config.n_window
841
+ self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
842
+ self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
843
+ self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim)
844
+ self.audio_bos_eos_token = nn.Embedding(2, config.output_dim)
845
+ self.layers = nn.ModuleList([Qwen2_5OmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)])
846
+ self.ln_post = nn.LayerNorm(config.d_model)
847
+ self.avg_pooler = nn.AvgPool1d(2, stride=2)
848
+ self.proj = nn.Linear(config.d_model, config.output_dim)
849
+ self.gradient_checkpointing = False
850
+ # Initialize weights and apply final processing
851
+ self.post_init()
852
+
853
+ def _freeze_parameters(self):
854
+ for param in self.parameters():
855
+ param.requires_grad = False
856
+ self._requires_grad = False
857
+
858
+ def get_input_embeddings(self) -> nn.Module:
859
+ return self.conv1
860
+
861
+ def set_input_embeddings(self, value):
862
+ self.conv1 = value
863
+
864
+ def forward(
865
+ self,
866
+ input_features,
867
+ feature_lens=None,
868
+ aftercnn_lens=None,
869
+ head_mask=None,
870
+ output_attentions=None,
871
+ output_hidden_states=None,
872
+ return_dict=None,
873
+ ):
874
+
875
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
876
+ output_hidden_states = (
877
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
878
+ )
879
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
880
+
881
+ chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
882
+
883
+ chunk_lengths = torch.tensor(
884
+ [self.n_window * 2] * chunk_num.sum(),
885
+ dtype=torch.long,
886
+ device=feature_lens.device,
887
+ )
888
+ tail_chunk_index = list(accumulate(chunk_num.tolist(), func=operator.add, initial=-1))[1:]
889
+ chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
890
+ chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths)
891
+
892
+ chunk_list = input_features.split(chunk_lengths.tolist(), dim=1)
893
+ padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function(
894
+ chunk_list, chunk_lengths, padding_value=0, padding_side="right"
895
+ )
896
+ padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask
897
+ padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2)
898
+
899
+ padded_embed = padded_embed + self.positional_embedding.positional_embedding[
900
+ : padded_embed.shape[1], :
901
+ ].unsqueeze(0).to(padded_embed.dtype)
902
+ hidden_states = padded_embed[padded_mask_after_cnn]
903
+ cu_seqlens = torch.cat(
904
+ (
905
+ torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32),
906
+ padded_mask_after_cnn.sum(1).cumsum(0),
907
+ )
908
+ ).to(torch.int32)
909
+ encoder_states: Optional[Tuple[Any, ...]] = () if output_hidden_states else None
910
+ all_attentions: Optional[Tuple[Any, ...]] = () if output_attentions else None
911
+
912
+ tmp_hidden_states = []
913
+ # check if head_mask has a correct number of layers specified if desired
914
+ if head_mask is not None and head_mask.size()[0] != (len(self.layers)):
915
+ raise ValueError(
916
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
917
+ )
918
+
919
+ for idx, encoder_layer in enumerate(self.layers):
920
+ if output_hidden_states and encoder_states is not None and hidden_states is not None:
921
+ encoder_states = encoder_states + (hidden_states,)
922
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
923
+ to_drop = False
924
+ if self.training:
925
+ dropout_probability = torch.rand([])
926
+ if dropout_probability < self.layerdrop: # skip the layer
927
+ to_drop = True
928
+
929
+ # Ignore copy
930
+ if to_drop:
931
+ layer_outputs = (None, None)
932
+ else:
933
+ if self.gradient_checkpointing and self.training:
934
+ layer_outputs = self._gradient_checkpointing_func(
935
+ encoder_layer.__call__,
936
+ hidden_states,
937
+ cu_seqlens,
938
+ (head_mask[idx] if head_mask is not None else None),
939
+ output_attentions,
940
+ )
941
+ else:
942
+ layer_outputs = encoder_layer(
943
+ hidden_states,
944
+ cu_seqlens,
945
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
946
+ output_attentions=output_attentions,
947
+ )
948
+
949
+ hidden_states = layer_outputs[0]
950
+ tmp_hidden_states.append(hidden_states)
951
+
952
+ if output_attentions and all_attentions is not None and layer_outputs is not None:
953
+ all_attentions = all_attentions + (layer_outputs[1],)
954
+
955
+ hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
956
+ token_audio_list = []
957
+ for each_audio_states in hidden_states_list:
958
+ each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1)
959
+ each_audio_states = self.ln_post(each_audio_states)
960
+ each_audio_states = self.proj(each_audio_states)
961
+ token_audio_list.append(each_audio_states)
962
+ token_audio = torch.cat(token_audio_list, dim=0)
963
+ if output_hidden_states and encoder_states is not None and token_audio is not None:
964
+ encoder_states = encoder_states + (token_audio,)
965
+
966
+ if not return_dict:
967
+ return tuple(v for v in [token_audio, encoder_states, all_attentions] if v is not None)
968
+ return BaseModelOutput(last_hidden_state=token_audio, hidden_states=encoder_states, attentions=all_attentions)
969
+
970
+ def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"):
971
+ max_len = tensor_len.max()
972
+ dim = tensor_list[0].shape[0]
973
+ padded_tensor = torch.full(
974
+ size=(len(tensor_list), dim, max_len),
975
+ fill_value=padding_value,
976
+ dtype=tensor_list[0].dtype,
977
+ device=tensor_list[0].device,
978
+ )
979
+
980
+ batch_mask = torch.zeros(
981
+ (len(tensor_len), max_len),
982
+ dtype=torch.long,
983
+ device=padded_tensor.device,
984
+ )
985
+ for i, length in enumerate(tensor_len):
986
+ batch_mask[i, :length] = 1
987
+ padded_tensor[i, :, :length] = tensor_list[i]
988
+
989
+ feature_lens_after_cnn = (tensor_len - 1) // 2 + 1
990
+ max_len_after_cnn = feature_lens_after_cnn.max()
991
+ batch_mask_after_cnn = torch.zeros(
992
+ (len(tensor_len), max_len_after_cnn),
993
+ dtype=torch.long,
994
+ device=padded_tensor.device,
995
+ )
996
+ for i, length in enumerate(feature_lens_after_cnn):
997
+ batch_mask_after_cnn[i, :length] = 1
998
+ return (
999
+ padded_tensor,
1000
+ batch_mask.unsqueeze(1),
1001
+ batch_mask_after_cnn.bool(),
1002
+ )
1003
+
1004
+ # Ignore copy
1005
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
1006
+ """
1007
+ Computes the output length of the convolutional layers and the output length of the audio encoder
1008
+ """
1009
+ input_lengths = (input_lengths - 1) // 2 + 1
1010
+ output_lengths = (input_lengths - 2) // 2 + 1
1011
+ return input_lengths, output_lengths
1012
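A worked example of the two length formulas, using hypothetical mel-frame counts: the first integer division mirrors the stride-2 `conv2` (kernel 3, padding 1), the second mirrors the stride-2 average pooling:

    import torch

    feature_lens = torch.tensor([100, 37])
    after_cnn = (feature_lens - 1) // 2 + 1   # conv2 output lengths: [50, 19]
    after_pool = (after_cnn - 2) // 2 + 1     # audio-token counts after avg_pooler: [25, 9]
    print(after_cnn.tolist(), after_pool.tolist())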
+
1013
+
1014
+ def rotate_half(x):
1015
+ """Rotates half the hidden dims of the input."""
1016
+ x1 = x[..., : x.shape[-1] // 2]
1017
+ x2 = x[..., x.shape[-1] // 2 :]
1018
+ return torch.cat((-x2, x1), dim=-1)
1019
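A minimal sketch of what `rotate_half` does on a toy vector, assuming the function above is in scope:

    import torch

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])   # halves are [1, 2] and [3, 4]
    print(rotate_half(x))                     # tensor([-3., -4.,  1.,  2.])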
+
1020
+
1021
+ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
1022
+ orig_dtype = tensor.dtype
1023
+ tensor = tensor.float()
1024
+ cos = freqs.cos()
1025
+ sin = freqs.sin()
1026
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
1027
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
1028
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
1029
+ output = output.to(orig_dtype)
1030
+ return output
1031
+
1032
+
1033
+ class Qwen2_5OmniVisionAttention(nn.Module):
1034
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
1035
+ super().__init__()
1036
+ self.num_heads = num_heads
1037
+ self.head_dim = dim // num_heads
1038
+ self.q = nn.Linear(dim, dim, bias=True)
1039
+ self.k = nn.Linear(dim, dim, bias=True)
1040
+ self.v = nn.Linear(dim, dim, bias=True)
1041
+ self.proj = nn.Linear(dim, dim)
1042
+
1043
+ def forward(
1044
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
1045
+ ) -> torch.Tensor:
1046
+ seq_length = hidden_states.shape[0]
1047
+ q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
1048
+ k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
1049
+ v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
1050
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
1051
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
1052
+
1053
+ attention_mask = torch.full(
1054
+ [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
1055
+ )
1056
+ for i in range(1, len(cu_seqlens)):
1057
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
1058
+
1059
+ q = q.transpose(0, 1)
1060
+ k = k.transpose(0, 1)
1061
+ v = v.transpose(0, 1)
1062
+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
1063
+ attn_weights = attn_weights + attention_mask
1064
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
1065
+ attn_output = torch.matmul(attn_weights, v)
1066
+ attn_output = attn_output.transpose(0, 1)
1067
+ attn_output = attn_output.reshape(seq_length, -1)
1068
+ attn_output = self.proj(attn_output)
1069
+ return attn_output
1070
+
1071
+
1072
+ class Qwen2_5OmniVisionFlashAttention2(nn.Module):
1073
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
1074
+ super().__init__()
1075
+ self.num_heads = num_heads
1076
+ self.q = nn.Linear(dim, dim, bias=True)
1077
+ self.k = nn.Linear(dim, dim, bias=True)
1078
+ self.v = nn.Linear(dim, dim, bias=True)
1079
+ self.proj = nn.Linear(dim, dim)
1080
+
1081
+ def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
1082
+ tensor_ = tensor.float()
1083
+ cos = freqs.cos() # .type_as(tensor_)
1084
+ sin = freqs.sin() # .type_as(tensor_)
1085
+ output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
1086
+ return output
1087
+
1088
+ def forward(
1089
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
1090
+ ) -> torch.Tensor:
1091
+ seq_length = hidden_states.shape[0]
1092
+ q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
1093
+ k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
1094
+ v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
1095
+ q = self._apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
1096
+ k = self._apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
1097
+
1098
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
1099
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
1100
+ seq_length, -1
1101
+ )
1102
+ attn_output = self.proj(attn_output)
1103
+ return attn_output
1104
+
1105
+
1106
+ class Qwen2_5OmniVisionSdpaAttention(nn.Module):
1107
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
1108
+ super().__init__()
1109
+ self.num_heads = num_heads
1110
+ self.q = nn.Linear(dim, dim, bias=True)
1111
+ self.k = nn.Linear(dim, dim, bias=True)
1112
+ self.v = nn.Linear(dim, dim, bias=True)
1113
+ self.proj = nn.Linear(dim, dim)
1114
+
1115
+ def forward(
1116
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
1117
+ ) -> torch.Tensor:
1118
+ seq_length = hidden_states.shape[0]
1119
+ q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
1120
+ k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
1121
+ v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
1122
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
1123
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
1124
+
1125
+ attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
1126
+ for i in range(1, len(cu_seqlens)):
1127
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
1128
+ q = q.transpose(0, 1)
1129
+ k = k.transpose(0, 1)
1130
+ v = v.transpose(0, 1)
1131
+ attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
1132
+ attn_output = attn_output.transpose(0, 1)
1133
+ attn_output = attn_output.reshape(seq_length, -1)
1134
+ attn_output = self.proj(attn_output)
1135
+ return attn_output
1136
+
1137
+
1138
+ class Qwen2_5OmniMLP(nn.Module):
1139
+ def __init__(self, config, bias: bool = False):
1140
+ super().__init__()
1141
+ self.hidden_size = config.hidden_size
1142
+ self.intermediate_size = config.intermediate_size
1143
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
1144
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
1145
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
1146
+ self.act_fn = ACT2FN[config.hidden_act]
1147
+
1148
+ def forward(self, hidden_state):
1149
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
1150
+
1151
+
1152
+ class Qwen2RMSNorm(nn.Module):
1153
+ def __init__(self, hidden_size, eps=1e-6):
1154
+ """
1155
+ Qwen2RMSNorm is equivalent to T5LayerNorm
1156
+ """
1157
+ super().__init__()
1158
+ self.weight = nn.Parameter(torch.ones(hidden_size))
1159
+ self.variance_epsilon = eps
1160
+
1161
+ def forward(self, hidden_states):
1162
+ input_dtype = hidden_states.dtype
1163
+ hidden_states = hidden_states.to(torch.float32)
1164
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
1165
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
1166
+ return self.weight * hidden_states.to(input_dtype)
1167
+
1168
+ def extra_repr(self):
1169
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
1170
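A short equivalence check for the RMS normalization above, assuming the `Qwen2RMSNorm` class is in scope; the freshly initialized weight is all ones, so the module output matches the bare RMS formula:

    import torch

    norm = Qwen2RMSNorm(hidden_size=4, eps=1e-6)
    x = torch.randn(2, 3, 4)
    manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    print(torch.allclose(norm(x), manual, atol=1e-6))  # True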
+
1171
+
1172
+ QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = {
1173
+ "eager": Qwen2_5OmniVisionAttention,
1174
+ "flash_attention_2": Qwen2_5OmniVisionFlashAttention2,
1175
+ "sdpa": Qwen2_5OmniVisionSdpaAttention,
1176
+ }
1177
+
1178
+
1179
+ class Qwen2_5OmniVisionBlock(nn.Module):
1180
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
1181
+ super().__init__()
1182
+ self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
1183
+ self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
1184
+ self.attn = QWEN2_5_OMNI_VISION_ATTENTION_CLASSES[attn_implementation](
1185
+ config.hidden_size, num_heads=config.num_heads
1186
+ )
1187
+ self.mlp = Qwen2_5OmniMLP(config, bias=True)
1188
+
1189
+ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
1190
+ hidden_states = hidden_states + self.attn(
1191
+ self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
1192
+ )
1193
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
1194
+ return hidden_states
1195
+
1196
+
1197
+ class Qwen2_5_VisionPatchEmbed(nn.Module):
1198
+ def __init__(
1199
+ self,
1200
+ patch_size: int = 14,
1201
+ temporal_patch_size: int = 2,
1202
+ in_channels: int = 3,
1203
+ embed_dim: int = 1152,
1204
+ ) -> None:
1205
+ super().__init__()
1206
+ self.patch_size = patch_size
1207
+ self.temporal_patch_size = temporal_patch_size
1208
+ self.in_channels = in_channels
1209
+ self.embed_dim = embed_dim
1210
+
1211
+ kernel_size = (temporal_patch_size, patch_size, patch_size)
1212
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
1213
+
1214
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1215
+ target_dtype = self.proj.weight.dtype
1216
+ hidden_states = hidden_states.view(
1217
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
1218
+ )
1219
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
1220
+ return hidden_states
1221
+
1222
+
1223
+ class Qwen2_5_VisionRotaryEmbedding(nn.Module):
1224
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
1225
+ super().__init__()
1226
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
1227
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
1228
+
1229
+ def forward(self, seqlen: int) -> torch.Tensor:
1230
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
1231
+ freqs = torch.outer(seq, self.inv_freq)
1232
+ return freqs
1233
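A small shape sketch for the vision rotary table, assuming the `Qwen2_5_VisionRotaryEmbedding` class above is in scope; `dim` here stands in for the `head_dim // 2` used by the encoder:

    import torch

    rope = Qwen2_5_VisionRotaryEmbedding(dim=8)
    freqs = rope(seqlen=4)
    print(freqs.shape)  # torch.Size([4, 4]): one rotation angle per (position, frequency) pair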
+
1234
+
1235
+ class Qwen2_5OmniPatchMerger(nn.Module):
1236
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
1237
+ super().__init__()
1238
+ self.hidden_size = context_dim * (spatial_merge_size**2)
1239
+ self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
1240
+ self.mlp = nn.Sequential(
1241
+ nn.Linear(self.hidden_size, self.hidden_size),
1242
+ nn.GELU(),
1243
+ nn.Linear(self.hidden_size, dim),
1244
+ )
1245
+
1246
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1247
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
1248
+ return x
1249
+
1250
+
1251
+ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
1252
+ config_class = Qwen2_5OmniVisionEncoderConfig
1253
+ _no_split_modules = ["Qwen2_5OmniVisionBlock"]
1254
+
1255
+ def __init__(self, config, *inputs, **kwargs) -> None:
1256
+ super().__init__(config, *inputs, **kwargs)
1257
+ self.spatial_merge_size = config.spatial_merge_size
1258
+ self.patch_size = config.patch_size
1259
+ self.fullatt_block_indexes = config.fullatt_block_indexes
1260
+ self.window_size = config.window_size
1261
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
1262
+
1263
+ self.patch_embed = Qwen2_5_VisionPatchEmbed(
1264
+ patch_size=config.patch_size,
1265
+ temporal_patch_size=config.temporal_patch_size,
1266
+ in_channels=config.in_channels,
1267
+ embed_dim=config.hidden_size,
1268
+ )
1269
+
1270
+ head_dim = config.hidden_size // config.num_heads
1271
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
1272
+ self.blocks = nn.ModuleList(
1273
+ [Qwen2_5OmniVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
1274
+ )
1275
+ self.merger = Qwen2_5OmniPatchMerger(
1276
+ dim=config.out_hidden_size,
1277
+ context_dim=config.hidden_size,
1278
+ spatial_merge_size=config.spatial_merge_size,
1279
+ )
1280
+ self.gradient_checkpointing = False
1281
+
1282
+ def rot_pos_emb(self, grid_thw):
1283
+ pos_ids = []
1284
+ for t, h, w in grid_thw:
1285
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
1286
+ hpos_ids = hpos_ids.reshape(
1287
+ h // self.spatial_merge_size,
1288
+ self.spatial_merge_size,
1289
+ w // self.spatial_merge_size,
1290
+ self.spatial_merge_size,
1291
+ )
1292
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
1293
+ hpos_ids = hpos_ids.flatten()
1294
+
1295
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
1296
+ wpos_ids = wpos_ids.reshape(
1297
+ h // self.spatial_merge_size,
1298
+ self.spatial_merge_size,
1299
+ w // self.spatial_merge_size,
1300
+ self.spatial_merge_size,
1301
+ )
1302
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
1303
+ wpos_ids = wpos_ids.flatten()
1304
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
1305
+ pos_ids_p = torch.cat(pos_ids, dim=0)
1306
+ max_grid_size = grid_thw[:, 1:].max()
1307
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
1308
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids_p].flatten(1)
1309
+ return rotary_pos_emb
1310
+
1311
+ def get_window_index(self, grid_thw):
1312
+ window_index: list = []
1313
+ cu_window_seqlens: list = [0]
1314
+ window_index_id = 0
1315
+ vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
1316
+
1317
+ for grid_t, grid_h, grid_w in grid_thw:
1318
+ llm_grid_h, llm_grid_w = (
1319
+ grid_h // self.spatial_merge_size,
1320
+ grid_w // self.spatial_merge_size,
1321
+ )
1322
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
1323
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
1324
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
1325
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
1326
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
1327
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
1328
+ index_padded = index_padded.reshape(
1329
+ grid_t,
1330
+ num_windows_h,
1331
+ vit_merger_window_size,
1332
+ num_windows_w,
1333
+ vit_merger_window_size,
1334
+ )
1335
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
1336
+ grid_t,
1337
+ num_windows_h * num_windows_w,
1338
+ vit_merger_window_size,
1339
+ vit_merger_window_size,
1340
+ )
1341
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
1342
+ index_padded = index_padded.reshape(-1)
1343
+ index_new = index_padded[index_padded != -100]
1344
+ window_index.append(index_new + window_index_id)
1345
+ cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
1346
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
1347
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
1348
+ window_index_p = torch.cat(window_index, dim=0)
1349
+
1350
+ return window_index_p, cu_window_seqlens
1351
+
1352
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
1353
+ """
1354
+ Args:
1355
+ hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
1356
+ The final hidden states of the model.
1357
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
1358
+ The temporal, height and width of feature shape of each image in LLM.
1359
+
1360
+ Returns:
1361
+ `torch.Tensor`: hidden_states.
1362
+ """
1363
+ hidden_states = self.patch_embed(hidden_states)
1364
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
1365
+
1366
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
1367
+ cu_window_seqlens = torch.tensor(
1368
+ cu_window_seqlens,
1369
+ device=hidden_states.device,
1370
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
1371
+ )
1372
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
1373
+
1374
+ seq_len, _ = hidden_states.size()
1375
+ hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
1376
+ hidden_states = hidden_states[window_index, :, :]
1377
+ hidden_states = hidden_states.reshape(seq_len, -1)
1378
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
1379
+ rotary_pos_emb = rotary_pos_emb[window_index, :, :]
1380
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
1381
+
1382
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
1383
+ dim=0,
1384
+ # Select dtype based on the following factors:
1385
+ # - FA2 requires that cu_seqlens_q must have dtype int32
1386
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
1387
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
1388
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
1389
+ )
1390
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
1391
+
1392
+ # Modification here
1393
+ for layer_num, blk in enumerate(self.blocks):
1394
+ if layer_num in self.fullatt_block_indexes:
1395
+ cu_seqlens_now = cu_seqlens
1396
+ else:
1397
+ cu_seqlens_now = cu_window_seqlens
1398
+ if self.gradient_checkpointing and self.training:
1399
+ hidden_states = self._gradient_checkpointing_func(
1400
+ blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb
1401
+ )
1402
+ else:
1403
+ hidden_states = blk(
1404
+ hidden_states,
1405
+ cu_seqlens=cu_seqlens_now,
1406
+ rotary_pos_emb=rotary_pos_emb,
1407
+ )
1408
+ hidden_states = self.merger(hidden_states)
1409
+ reverse_indices = torch.argsort(window_index)
1410
+ hidden_states = hidden_states[reverse_indices, :]
1411
+
1412
+ return hidden_states
1413
+
1414
+ def get_dtype(self) -> torch.dtype:
1415
+ return self.blocks[0].mlp.gate_proj.weight.dtype
1416
+
1417
+ def get_device(self) -> torch.device:
1418
+ return self.blocks[0].mlp.gate_proj.weight.device
1419
+
1420
+
1421
+ class Qwen2_5OmniRotaryEmbedding(nn.Module):
1422
+ def __init__(self, config: Qwen2_5OmniThinkerConfig, device=None):
1423
+ super().__init__()
1424
+ # BC: "rope_type" was originally "type"
1425
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
1426
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
1427
+ else:
1428
+ self.rope_type = "default"
1429
+ self.max_seq_len_cached = config.max_position_embeddings
1430
+ self.original_max_seq_len = config.max_position_embeddings
1431
+
1432
+ self.config = config
1433
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
1434
+
1435
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
1436
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
1437
+ self.original_inv_freq = self.inv_freq
1438
+
1439
+ def _dynamic_frequency_update(self, position_ids, device):
1440
+ """
1441
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
1442
+ 1 - growing beyond the cached sequence length (allow scaling)
1443
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
1444
+ """
1445
+ seq_len = torch.max(position_ids) + 1
1446
+ if seq_len > self.max_seq_len_cached: # growth
1447
+ inv_freq, self.attention_scaling = self.rope_init_fn(
1448
+ self.config, device, seq_len=seq_len
1449
+ )
1450
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
1451
+ self.max_seq_len_cached = seq_len
1452
+
1453
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
1454
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
1455
+ self.max_seq_len_cached = self.original_max_seq_len
1456
+
1457
+ @torch.no_grad()
1458
+ def forward(self, x, position_ids):
1459
+ if "dynamic" in self.rope_type:
1460
+ self._dynamic_frequency_update(position_ids, device=x.device)
1461
+
1462
+ # Core RoPE block. In contrast to other models, Qwen2_5Omni has different position ids for the grids
1463
+ # So we expand the inv_freq to shape (3, ...)
1464
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
1465
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
1466
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
1467
+ device_type = x.device.type
1468
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
1469
+ with torch.autocast(device_type=device_type, enabled=False):
1470
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
1471
+ emb = torch.cat((freqs, freqs), dim=-1)
1472
+ cos = emb.cos()
1473
+ sin = emb.sin()
1474
+
1475
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
1476
+ cos = cos * self.attention_scaling
1477
+ sin = sin * self.attention_scaling
1478
+
1479
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
1480
+
1481
+
1482
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
1483
+ mrope_section = mrope_section * 2
1484
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
1485
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
1486
+
1487
+ q_embed = (q * cos) + (rotate_half(q) * sin)
1488
+ k_embed = (k * cos) + (rotate_half(k) * sin)
1489
+ return q_embed, k_embed
1490
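A sketch of the `mrope_section` recombination above, with hypothetical sizes: each slice of the head dimension takes its cos/sin table from the temporal, height, or width position ids in turn, and the doubled section list covers both duplicated halves of the head dimension:

    import torch

    mrope_section = [2, 3, 3]                           # shares of half the head_dim
    cos = torch.randn(3, 1, 5, 2 * sum(mrope_section))  # (3 grid axes, batch, seq, head_dim)

    chunks = cos.split(mrope_section * 2, dim=-1)
    merged = torch.cat([m[i % 3] for i, m in enumerate(chunks)], dim=-1)
    print(merged.shape)  # torch.Size([1, 5, 16]): one combined table over the full head_dim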
+
1491
+
1492
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
1493
+ """
1494
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
1495
+ The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
1496
+ (batch, num_attention_heads, seqlen, head_dim)
1497
+ """
1498
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
1499
+ if n_rep == 1:
1500
+ return hidden_states
1501
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
1502
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
1503
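A quick check of the equivalence stated in the docstring, assuming `repeat_kv` above is in scope and using hypothetical shapes:

    import torch

    kv = torch.randn(1, 2, 5, 8)       # (batch, num_key_value_heads, seq_len, head_dim)
    expanded = repeat_kv(kv, n_rep=4)  # -> (1, 8, 5, 8)
    print(torch.equal(expanded, kv.repeat_interleave(4, dim=1)))  # True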
+
1504
+
1505
+ class Qwen2_5OmniAttention(nn.Module):
1506
+
1507
+ def __init__(self, config: Qwen2_5OmniConfig, layer_idx: Optional[int] = None):
1508
+ super().__init__()
1509
+ self.config = config
1510
+ self.layer_idx = layer_idx
1511
+ if layer_idx is None:
1512
+ logger.warning_once(
1513
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
1514
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
1515
+ "when creating this class."
1516
+ )
1517
+
1518
+ self.hidden_size = config.hidden_size
1519
+ self.num_heads = config.num_attention_heads
1520
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
1521
+ self.num_key_value_heads = config.num_key_value_heads
1522
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
1523
+ self.is_causal = True
1524
+ self.attention_dropout = config.attention_dropout
1525
+ self.rope_scaling = config.rope_scaling
1526
+
1527
+ # if (self.head_dim * self.num_heads) != self.hidden_size:
1528
+ # raise ValueError(
1529
+ # f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
1530
+ # f" and `num_heads`: {self.num_heads})."
1531
+ # )
1532
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
1533
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
1534
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
1535
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
1536
+
1537
+ self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
1538
+
1539
+ def forward(
1540
+ self,
1541
+ hidden_states: torch.Tensor,
1542
+ attention_mask: Optional[torch.Tensor] = None,
1543
+ position_ids: Optional[torch.LongTensor] = None,
1544
+ past_key_value: Optional[Cache] = None,
1545
+ output_attentions: bool = False,
1546
+ use_cache: bool = False,
1547
+ cache_position: Optional[torch.LongTensor] = None,
1548
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
1549
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1550
+ bsz, q_len, _ = hidden_states.size()
1551
+
1552
+ query_states = self.q_proj(hidden_states)
1553
+ key_states = self.k_proj(hidden_states)
1554
+ value_states = self.v_proj(hidden_states)
1555
+
1556
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
1557
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
1558
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
1559
+
1560
+ assert position_embeddings is not None
1561
+ cos, sin = position_embeddings
1562
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
1563
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
1564
+ )
1565
+
1566
+ if past_key_value is not None:
1567
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
1568
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1569
+
1570
+ # repeat k/v heads if n_kv_heads < n_heads
1571
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
1572
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
1573
+
1574
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
1575
+
1576
+ if attention_mask is not None: # no matter the length, we just slice it
1577
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
1578
+ attn_weights = attn_weights + causal_mask
1579
+
1580
+ # Fix precision issues in Qwen2-VL float16 inference
1581
+ # Replace inf values with zeros in attention weights to prevent NaN propagation
1582
+ if query_states.dtype == torch.float16:
1583
+ attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
1584
+
1585
+ # upcast attention to fp32
1586
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
1587
+ attn_weights_p: torch.Tensor = nn.functional.dropout(
1588
+ attn_weights, p=self.attention_dropout, training=self.training
1589
+ )
1590
+ return_attn_weights: Optional[torch.Tensor] = attn_weights_p
1591
+ attn_output = torch.matmul(attn_weights_p, value_states)
1592
+
1593
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
1594
+ raise ValueError(
1595
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
1596
+ f" {attn_output.size()}"
1597
+ )
1598
+
1599
+ attn_output = attn_output.transpose(1, 2).contiguous()
1600
+ attn_output = attn_output.reshape(bsz, q_len, -1)
1601
+
1602
+ attn_output = self.o_proj(attn_output)
1603
+
1604
+ if not output_attentions:
1605
+ return_attn_weights = None
1606
+
1607
+ return attn_output, return_attn_weights, past_key_value
1608
+
1609
+
1610
+ class Qwen2MLP(nn.Module):
1611
+ def __init__(self, config, bias: bool = False):
1612
+ super().__init__()
1613
+ self.hidden_size = config.hidden_size
1614
+ self.intermediate_size = config.intermediate_size
1615
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
1616
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
1617
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
1618
+ self.act_fn = ACT2FN[config.hidden_act]
1619
+
1620
+ def forward(self, hidden_state):
1621
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
1622
+
1623
+
1624
+ class Qwen2_5OmniFlashAttention2(Qwen2_5OmniAttention):
1625
+
1626
+ def __init__(self, *args, **kwargs):
1627
+ super().__init__(*args, **kwargs)
1628
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
1629
+
1630
+ def forward(
1631
+ self,
1632
+ hidden_states: torch.Tensor,
1633
+ attention_mask: Optional[torch.Tensor] = None,
1634
+ position_ids: Optional[torch.LongTensor] = None,
1635
+ past_key_value: Optional[Cache] = None,
1636
+ output_attentions: bool = False,
1637
+ use_cache: bool = False,
1638
+ cache_position: Optional[torch.LongTensor] = None,
1639
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
1640
+ ):
1641
+ bsz, q_len, _ = hidden_states.size()
1642
+
1643
+ query_states = self.q_proj(hidden_states)
1644
+ key_states = self.k_proj(hidden_states)
1645
+ value_states = self.v_proj(hidden_states)
1646
+
1647
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
1648
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
1649
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
1650
+
1651
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
1652
+ assert position_embeddings is not None
1653
+ cos, sin = position_embeddings
1654
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
1655
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
1656
+ )
1657
+
1658
+ if past_key_value is not None:
1659
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
1660
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1661
+
1662
+ # repeat k/v heads if n_kv_heads < n_heads
1663
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
1664
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
1665
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
1666
+
1667
+ # In PEFT, we usually cast the layer norms to float32 for training stability reasons;
1668
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
1669
+ # cast them back to float16 just to be sure everything works as expected.
1670
+ input_dtype = query_states.dtype
1671
+ if input_dtype == torch.float32:
1672
+ if torch.is_autocast_enabled():
1673
+ target_dtype = torch.get_autocast_gpu_dtype()
1674
+ # Handle the case where the model is quantized
1675
+ elif hasattr(self.config, "_pre_quantization_dtype"):
1676
+ target_dtype = self.config._pre_quantization_dtype
1677
+ else:
1678
+ target_dtype = self.q_proj.weight.dtype
1679
+
1680
+ logger.warning_once(
1681
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
1682
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
1683
+ f" {target_dtype}."
1684
+ )
1685
+
1686
+ query_states = query_states.to(target_dtype)
1687
+ key_states = key_states.to(target_dtype)
1688
+ value_states = value_states.to(target_dtype)
1689
+
1690
+ # Reshape to the expected shape for Flash Attention
1691
+ query_states = query_states.transpose(1, 2)
1692
+ key_states = key_states.transpose(1, 2)
1693
+ value_states = value_states.transpose(1, 2)
1694
+
1695
+ if (
1696
+ self.config.use_sliding_window
1697
+ and getattr(self.config, "sliding_window", None) is not None
1698
+ and self.layer_idx >= self.config.max_window_layers
1699
+ ):
1700
+ sliding_window = self.config.sliding_window
1701
+ else:
1702
+ sliding_window = None
1703
+
1704
+ attn_output = _flash_attention_forward(
1705
+ query_states,
1706
+ key_states,
1707
+ value_states,
1708
+ attention_mask,
1709
+ q_len,
1710
+ dropout=dropout_rate,
1711
+ sliding_window=sliding_window,
1712
+ is_causal=self.is_causal,
1713
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
1714
+ )
1715
+
1716
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
1717
+ attn_output = self.o_proj(attn_output)
1718
+
1719
+ if not output_attentions:
1720
+ attn_weights = None
1721
+
1722
+ return attn_output, attn_weights, past_key_value
1723
+
1724
+
1725
+ class Qwen2_5OmniSdpaAttention(Qwen2_5OmniAttention):
1726
+ """
1727
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
1728
+ `Qwen2Attention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
1729
+ the SDPA API.
1730
+ """
1731
+
1732
+ # Adapted from Qwen2Attention.forward
1733
+ def forward(
1734
+ self,
1735
+ hidden_states: torch.Tensor,
1736
+ attention_mask: Optional[torch.Tensor] = None,
1737
+ position_ids: Optional[torch.LongTensor] = None,
1738
+ past_key_value: Optional[Cache] = None,
1739
+ output_attentions: bool = False,
1740
+ use_cache: bool = False,
1741
+ cache_position: Optional[torch.LongTensor] = None,
1742
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1743
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1744
+ if output_attentions:
1745
+ logger.warning_once(
1746
+ "Qwen2_5OmniModel is using Qwen2_5OmniSdpaAttention, but "
1747
+ "`torch.nn.functional.scaled_dot_product_attention`"
1748
+ " does not support `output_attentions=True`. Falling back to "
1749
+ "the manual attention implementation, "
1750
+ "but specifying the manual implementation will be required from "
1751
+ "Transformers version v5.0.0 onwards."
1752
+ " This warning can be removed using the argument "
1753
+ '`attn_implementation="eager"` when loading the model.'
1754
+ )
1755
+ return super().forward(
1756
+ hidden_states=hidden_states,
1757
+ attention_mask=attention_mask,
1758
+ position_ids=position_ids,
1759
+ past_key_value=past_key_value,
1760
+ output_attentions=output_attentions,
1761
+ use_cache=use_cache,
1762
+ cache_position=cache_position,
1763
+ position_embeddings=position_embeddings,
1764
+ )
1765
+
1766
+ bsz, q_len, _ = hidden_states.size()
1767
+
1768
+ query_states = self.q_proj(hidden_states)
1769
+ key_states = self.k_proj(hidden_states)
1770
+ value_states = self.v_proj(hidden_states)
1771
+
1772
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
1773
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
1774
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
1775
+
1776
+ assert position_embeddings is not None
1777
+ cos, sin = position_embeddings
1778
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
1779
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
1780
+ )
1781
+
1782
+ if past_key_value is not None:
1783
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
1784
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1785
+
1786
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
1787
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
1788
+
1789
+ causal_mask = attention_mask
1790
+ if attention_mask is not None: # no matter the length, we just slice it
1791
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
1792
+
1793
+ if query_states.device.type == "cuda" and attention_mask is not None:
1794
+ query_states = query_states.contiguous()
1795
+ key_states = key_states.contiguous()
1796
+ value_states = value_states.contiguous()
1797
+
1798
+ is_causal = True if causal_mask is None and q_len > 1 else False
1799
+
1800
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
1801
+ query_states,
1802
+ key_states,
1803
+ value_states,
1804
+ attn_mask=causal_mask,
1805
+ dropout_p=self.attention_dropout if self.training else 0.0,
1806
+ is_causal=is_causal,
1807
+ )
1808
+
1809
+ attn_output = attn_output.transpose(1, 2).contiguous()
1810
+ attn_output = attn_output.view(bsz, q_len, -1)
1811
+
1812
+ attn_output = self.o_proj(attn_output)
1813
+
1814
+ return attn_output, None, past_key_value
1815
+
1816
+
1817
+ QWEN2_5_OMNI_ATTENTION_CLASSES = {
1818
+ "eager": Qwen2_5OmniAttention,
1819
+ "flash_attention_2": Qwen2_5OmniFlashAttention2,
1820
+ "sdpa": Qwen2_5OmniSdpaAttention,
1821
+ }
1822
+
1823
+
1824
+ class Qwen2_5OmniDecoderLayer(nn.Module):
1825
+ def __init__(self, config: Qwen2_5OmniConfig, layer_idx: int):
1826
+ super().__init__()
1827
+ self.hidden_size = config.hidden_size
1828
+
1829
+ if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
1830
+ logger.warning_once(
1831
+ f"Sliding Window Attention is enabled but not implemented for "
1832
+ f"`{config._attn_implementation}`; "
1833
+ f"unexpected results may be encountered."
1834
+ )
1835
+ self.self_attn = QWEN2_5_OMNI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
1836
+
1837
+ self.mlp = Qwen2MLP(config)
1838
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1839
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1840
+
1841
+ def forward(
1842
+ self,
1843
+ hidden_states: torch.Tensor,
1844
+ attention_mask: Optional[torch.Tensor] = None,
1845
+ position_ids: Optional[torch.LongTensor] = None,
1846
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1847
+ output_attentions: Optional[bool] = False,
1848
+ use_cache: Optional[bool] = False,
1849
+ cache_position: Optional[torch.LongTensor] = None,
1850
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
1851
+ **kwargs,
1852
+ ):
1853
+
1854
+ residual = hidden_states
1855
+
1856
+ hidden_states = self.input_layernorm(hidden_states)
1857
+
1858
+ # Self Attention
1859
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1860
+ hidden_states=hidden_states,
1861
+ attention_mask=attention_mask,
1862
+ position_ids=position_ids,
1863
+ past_key_value=past_key_value,
1864
+ output_attentions=output_attentions,
1865
+ use_cache=use_cache,
1866
+ cache_position=cache_position,
1867
+ position_embeddings=position_embeddings,
1868
+ )
1869
+ hidden_states = residual + hidden_states
1870
+
1871
+ # Fully Connected
1872
+ residual = hidden_states
1873
+ hidden_states = self.post_attention_layernorm(hidden_states)
1874
+ hidden_states = self.mlp(hidden_states)
1875
+ hidden_states = residual + hidden_states
1876
+
1877
+ outputs: Tuple[Any, ...]
1878
+ outputs = (hidden_states,)
1879
+
1880
+ if output_attentions:
1881
+ outputs += (self_attn_weights,)
1882
+
1883
+ if use_cache:
1884
+ outputs += (present_key_value,)
1885
+
1886
+ return outputs
1887
+
1888
+
1889
+ QWEN2_5OMNI_START_DOCSTRING = r"""add doc"""
1890
+
1891
+
1892
+ @add_start_docstrings(
1893
+ "The bare Qwen2.5OmniThinker Model outputting raw hidden-states without any specific head on top.",
1894
+ QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTextConfig"),
1895
+ )
1896
+ class Qwen2_5OmniThinkerModel(Qwen2_5OmniPreTrainedModel):
1897
+ config_class = Qwen2_5OmniTextConfig
1898
+ _no_split_modules = ["Qwen2_5OmniDecoderLayer"]
1899
+
1900
+ def __init__(self, config: Qwen2_5OmniTextConfig):
1901
+ super().__init__(config)
1902
+ self.padding_idx = config.pad_token_id
1903
+ self.vocab_size = config.vocab_size
1904
+
1905
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1906
+ self.layers = nn.ModuleList(
1907
+ [Qwen2_5OmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1908
+ )
1909
+ self._attn_implementation = config._attn_implementation
1910
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1911
+ self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
1912
+
1913
+ self.gradient_checkpointing = False
1914
+ # Initialize weights and apply final processing
1915
+ self.post_init()
1916
+
1917
+ def get_input_embeddings(self):
1918
+ return self.embed_tokens
1919
+
1920
+ def set_input_embeddings(self, value):
1921
+ self.embed_tokens = value
1922
+
1923
+ def forward(
1924
+ self,
1925
+ input_ids: Optional[torch.LongTensor] = None,
1926
+ attention_mask: Optional[torch.Tensor] = None,
1927
+ position_ids: Optional[torch.Tensor] = None,
1928
+ past_key_values=None,
1929
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1930
+ use_cache: Optional[bool] = None,
1931
+ output_attentions: Optional[bool] = None,
1932
+ output_hidden_states: Optional[bool] = None,
1933
+ return_dict: Optional[bool] = None,
1934
+ cache_position: Optional[torch.Tensor] = None,
1935
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1936
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1937
+ output_hidden_states = (
1938
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1939
+ )
1940
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1941
+
1942
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1943
+
1944
+ if (input_ids is None) ^ (inputs_embeds is not None):
1945
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1946
+
1947
+ if self.gradient_checkpointing and self.training:
1948
+ if use_cache:
1949
+ logger.warning_once(
1950
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1951
+ )
1952
+ use_cache = False
1953
+
1954
+ # torch.jit.trace() doesn't support cache objects in the output
1955
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
1956
+ past_key_values = DynamicCache()
1957
+
1958
+ if inputs_embeds is None:
1959
+ inputs_embeds = self.embed_tokens(input_ids)
1960
+
1961
+ if cache_position is None:
1962
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1963
+ cache_position = torch.arange(
1964
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1965
+ )
1966
+
1967
+ # the hard coded `3` is for temporal, height and width.
1968
+ if position_ids is None:
1969
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
1970
+ elif position_ids.dim() == 2:
1971
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
1972
+
1973
+ causal_mask = self._update_causal_mask(
1974
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1975
+ )
1976
+
1977
+ hidden_states = inputs_embeds
1978
+
1979
+ # create position embeddings to be shared across the decoder layers
1980
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1981
+
1982
+ # decoder layers
1983
+ all_hidden_states: Optional[Tuple[Any, ...]] = () if output_hidden_states else None
1984
+ all_self_attns: Optional[Tuple[Any, ...]] = () if output_attentions else None
1985
+ next_decoder_cache = None
1986
+
1987
+ for decoder_layer in self.layers:
1988
+ if output_hidden_states and hidden_states is not None and all_hidden_states is not None:
1989
+ all_hidden_states += (hidden_states,)
1990
+
1991
+ if self.gradient_checkpointing and self.training:
1992
+ layer_outputs = self._gradient_checkpointing_func(
1993
+ decoder_layer.__call__,
1994
+ hidden_states,
1995
+ causal_mask,
1996
+ position_ids,
1997
+ past_key_values,
1998
+ output_attentions,
1999
+ use_cache,
2000
+ cache_position,
2001
+ position_embeddings,
2002
+ )
2003
+ else:
2004
+ layer_outputs = decoder_layer(
2005
+ hidden_states,
2006
+ attention_mask=causal_mask,
2007
+ position_ids=position_ids,
2008
+ past_key_value=past_key_values,
2009
+ output_attentions=output_attentions,
2010
+ use_cache=use_cache,
2011
+ cache_position=cache_position,
2012
+ position_embeddings=position_embeddings,
2013
+ )
2014
+
2015
+ hidden_states = layer_outputs[0]
2016
+
2017
+ if use_cache:
2018
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
2019
+
2020
+ if output_attentions and layer_outputs is not None and all_self_attns is not None:
2021
+ all_self_attns += (layer_outputs[1],)
2022
+
2023
+ hidden_states = self.norm(hidden_states)
2024
+
2025
+ # add hidden states from the last decoder layer
2026
+ if output_hidden_states and all_hidden_states is not None:
2027
+ all_hidden_states += (hidden_states,)
2028
+
2029
+ next_cache = next_decoder_cache if use_cache else None
2030
+
2031
+ if not return_dict:
2032
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
2033
+ return BaseModelOutputWithPast(
2034
+ last_hidden_state=hidden_states,
2035
+ past_key_values=next_cache,
2036
+ hidden_states=all_hidden_states,
2037
+ attentions=all_self_attns,
2038
+ )
2039
+
2040
+ def _update_causal_mask(
2041
+ self,
2042
+ attention_mask,
2043
+ input_tensor: torch.Tensor,
2044
+ cache_position: torch.Tensor,
2045
+ past_key_values: Cache,
2046
+ output_attentions: bool,
2047
+ ):
2048
+ if self.config._attn_implementation == "flash_attention_2":
2049
+ if attention_mask is not None and past_key_values is not None:
2050
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
2051
+ if is_padding_right:
2052
+ raise ValueError(
2053
+ "You are attempting to perform batched generation with padding_side='right'"
2054
+ " this may lead to unexpected behaviour for Flash Attention version "
2055
+ "of Qwen25OmniThinker. Make sure to "
2056
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
2057
+ )
2058
+ if attention_mask is not None and 0.0 in attention_mask:
2059
+ return attention_mask
2060
+ return None
2061
+
2062
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
2063
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
2064
+ # to infer the attention mask.
2065
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
2066
+ using_static_cache = isinstance(past_key_values, StaticCache)
2067
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
2068
+
2069
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
2070
+ if (
2071
+ self.config._attn_implementation == "sdpa"
2072
+ and not (using_static_cache or using_sliding_window_cache)
2073
+ and not output_attentions
2074
+ ):
2075
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
2076
+ attention_mask,
2077
+ inputs_embeds=input_tensor,
2078
+ past_key_values_length=past_seen_tokens,
2079
+ sliding_window=self.config.sliding_window,
2080
+ is_training=self.training,
2081
+ ):
2082
+ return None
2083
+
2084
+ dtype, device = input_tensor.dtype, input_tensor.device
2085
+ min_dtype = torch.finfo(dtype).min
2086
+ sequence_length = input_tensor.shape[1]
2087
+ # SlidingWindowCache or StaticCache
2088
+ if using_sliding_window_cache or using_static_cache:
2089
+ target_length = past_key_values.get_max_cache_shape()
2090
+ # DynamicCache or no cache
2091
+ else:
2092
+ target_length = (
2093
+ attention_mask.shape[-1]
2094
+ if isinstance(attention_mask, torch.Tensor)
2095
+ else past_seen_tokens + sequence_length + 1
2096
+ )
2097
+
2098
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
2099
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
2100
+ attention_mask,
2101
+ sequence_length=sequence_length,
2102
+ target_length=target_length,
2103
+ dtype=dtype,
2104
+ device=device,
2105
+ cache_position=cache_position,
2106
+ batch_size=input_tensor.shape[0],
2107
+ config=self.config,
2108
+ past_key_values=past_key_values,
2109
+ )
2110
+
2111
+ if (
2112
+ self.config._attn_implementation == "sdpa"
2113
+ and attention_mask is not None
2114
+ and attention_mask.device.type in ["cuda", "xpu"]
2115
+ and not output_attentions
2116
+ ):
2117
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
2118
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
2119
+ # Details: https://github.com/pytorch/pytorch/issues/110213
2120
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
2121
+
2122
+ return causal_mask
2123
+
2124
+ @staticmethod
2125
+ def _prepare_4d_causal_attention_mask_with_cache_position(
2126
+ attention_mask: torch.Tensor,
2127
+ sequence_length: int,
2128
+ target_length: int,
2129
+ dtype: torch.dtype,
2130
+ device: torch.device,
2131
+ cache_position: torch.Tensor,
2132
+ batch_size: int,
2133
+ config: Qwen2_5OmniConfig,
2134
+ past_key_values: Cache,
2135
+ ):
2136
+ if attention_mask is not None and attention_mask.dim() == 4:
2137
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
2138
+ causal_mask = attention_mask
2139
+ else:
2140
+ min_dtype = torch.finfo(dtype).min
2141
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
2142
+ diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
2143
+ if config.sliding_window is not None:
2144
+
2145
+ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
2146
+ sliding_attend_mask = torch.arange(target_length, device=device) <= (
2147
+ cache_position.reshape(-1, 1) - config.sliding_window
2148
+ )
2149
+ diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
2150
+ causal_mask *= diagonal_attend_mask
2151
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
2152
+ if attention_mask is not None:
2153
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
2154
+ if attention_mask.shape[-1] > target_length:
2155
+ attention_mask = attention_mask[:, :target_length]
2156
+ mask_length = attention_mask.shape[-1]
2157
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
2158
+ causal_mask.device
2159
+ )
2160
+ padding_mask = padding_mask == 0
2161
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
2162
+ padding_mask, min_dtype
2163
+ )
2164
+ return causal_mask
2165
+
2166
+
2167
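Editorial note: the `_update_causal_mask` / `_prepare_4d_causal_attention_mask_with_cache_position` pair above expands a 2-D padding mask into a 4-D additive causal mask. The following is a minimal, self-contained sketch of that idea only; it ignores sliding windows and KV caches, and all names in it are illustrative rather than taken from this diff.

import torch

def make_4d_causal_mask(padding_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # padding_mask: [batch, seq_len], 1 for real tokens, 0 for padding.
    batch, seq_len = padding_mask.shape
    min_val = torch.finfo(dtype).min
    # Start fully blocked, then open the lower triangle (causal attention).
    mask = torch.full((seq_len, seq_len), min_val, dtype=dtype)
    mask = mask.masked_fill(torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)), 0.0)
    # Broadcast to [batch, 1, seq_len, seq_len] and block padded key positions.
    mask = mask[None, None, :, :].expand(batch, 1, seq_len, seq_len).clone()
    return mask.masked_fill(padding_mask[:, None, None, :] == 0, min_val)

pad = torch.tensor([[1, 1, 1], [0, 1, 1]])          # second sequence is left-padded
print(make_4d_causal_mask(pad).shape)               # torch.Size([2, 1, 3, 3])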
+ @add_start_docstrings(
2168
+ """The Qwen2.5OmniThinker model, which consists of an audio backbone and a language model.""",
2169
+ QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniThinkerConfig"),
2170
+ )
2171
+ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
2172
+ config_class = Qwen2_5OmniThinkerConfig
2173
+ _no_split_modules = ["Qwen2_5OmniAudioEncoder", "Qwen2_5OmniVisionEncoder"]
2174
+
2175
+ def __init__(self, config: Qwen2_5OmniThinkerConfig):
2176
+ super().__init__(config)
2177
+ self.audio_tower = Qwen2_5OmniAudioEncoder._from_config(
2178
+ config.audio_config, attn_implementation=config._attn_implementation
2179
+ )
2180
+
2181
+ self.visual = Qwen2_5OmniVisionEncoder._from_config(
2182
+ config.vision_config, attn_implementation=config._attn_implementation
2183
+ )
2184
+
2185
+ self.vocab_size = config.text_config.vocab_size
2186
+ self.model = Qwen2_5OmniThinkerModel._from_config(
2187
+ config.text_config, attn_implementation=config._attn_implementation
2188
+ )
2189
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
2190
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
2191
+ self.spatial_merge_size = config.vision_config.spatial_merge_size
2192
+ self.post_init()
2193
+
2194
+ def forward(
2195
+ self,
2196
+ input_ids: Optional[torch.LongTensor] = None,
2197
+ input_features: Optional[torch.Tensor] = None,
2198
+ pixel_values: Optional[torch.Tensor] = None,
2199
+ pixel_values_videos: Optional[torch.Tensor] = None,
2200
+ image_grid_thw: Optional[torch.LongTensor] = None,
2201
+ video_grid_thw: Optional[torch.LongTensor] = None,
2202
+ attention_mask: Optional[torch.Tensor] = None,
2203
+ feature_attention_mask: Optional[torch.Tensor] = None,
2204
+ audio_feature_lengths: Optional[torch.Tensor] = None,
2205
+ position_ids: Optional[torch.Tensor] = None,
2206
+ past_key_values: Optional[List[torch.Tensor]] = None,
2207
+ inputs_embeds: Optional[torch.Tensor] = None,
2208
+ rope_deltas: Optional[torch.Tensor] = None,
2209
+ labels: Optional[torch.LongTensor] = None,
2210
+ use_cache: Optional[bool] = None,
2211
+ output_attentions: Optional[bool] = None,
2212
+ output_hidden_states: Optional[bool] = None,
2213
+ return_dict: Optional[bool] = None,
2214
+ use_audio_in_video: Optional[bool] = None,
2215
+ cache_position: Optional[torch.Tensor] = None,
2216
+ video_second_per_grid: Optional[torch.LongTensor] = None,
2217
+ ) -> Union[Tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
2218
+
2219
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2220
+ output_hidden_states = (
2221
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2222
+ )
2223
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2224
+
2225
+ if feature_attention_mask is not None and input_features is not None:
2226
+ audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
2227
+ input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
2228
+ else:
2229
+ audio_feature_lengths = None
2230
+ if attention_mask is not None and position_ids is None:
2231
+ if cache_position is None or (cache_position is not None and cache_position[0] == 0):
2232
+ delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
2233
+ position_ids_p, rope_deltas = self.get_rope_index(
2234
+ input_ids,
2235
+ image_grid_thw,
2236
+ video_grid_thw,
2237
+ attention_mask,
2238
+ use_audio_in_video,
2239
+ audio_feature_lengths,
2240
+ video_second_per_grid,
2241
+ )
2242
+ rope_deltas = rope_deltas - delta0
2243
+
2244
+ else:
2245
+ assert input_ids is not None
2246
+ batch_size, seq_length = input_ids.shape
2247
+ delta = (
2248
+ cache_position[0] + rope_deltas
2249
+ if cache_position is not None and rope_deltas is not None
2250
+ else torch.tensor(0, device=input_ids.device)
2251
+ )
2252
+ position_ids = torch.arange(seq_length, device=input_ids.device)
2253
+ position_ids_p = position_ids.view(1, -1).expand(batch_size, -1)
2254
+ position_ids_p = position_ids_p.add(delta)
2255
+ position_ids_p = position_ids_p.unsqueeze(0).expand(3, -1, -1)
2256
+
2257
+ if inputs_embeds is None and input_ids is not None:
2258
+ # 1. Extract the input embeddings
2259
+ inputs_embeds = self.get_input_embeddings()(input_ids)
2260
+ embeds_to_talker = inputs_embeds.clone()
2261
+
2262
+ # 2. Merge text, audio, image and video embeddings
2263
+ if input_ids.shape[1] != 1:
2264
+ if input_features is not None and feature_attention_mask is not None:
2265
+ audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
2266
+ audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
2267
+ )
2268
+ feature_lens = (
2269
+ audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
2270
+ )
2271
+ audio_outputs = self.audio_tower(
2272
+ input_features,
2273
+ feature_lens=feature_lens,
2274
+ aftercnn_lens=audio_feat_lengths,
2275
+ )
2276
+ audio_features = audio_outputs.last_hidden_state
2277
+ if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
2278
+ raise ValueError("length of audio_features should match audio_output_lengths")
2279
+ audio_mask = (input_ids == self.config.audio_token_index).unsqueeze(-1).expand_as(inputs_embeds)
2280
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
2281
+ inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
2282
+ embeds_to_talker = embeds_to_talker.masked_scatter(audio_mask, torch.zeros_like(audio_features))
2283
+
2284
+ if pixel_values is not None:
2285
+ pixel_values = pixel_values.type(self.visual.get_dtype())
2286
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
2287
+ image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
2288
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
2289
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
2290
+ embeds_to_talker = embeds_to_talker.masked_scatter(image_mask, torch.zeros_like(image_embeds))
2291
+
2292
+ if pixel_values_videos is not None:
2293
+ pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
2294
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
2295
+ video_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
2296
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
2297
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
2298
+ embeds_to_talker = embeds_to_talker.masked_scatter(video_mask, torch.zeros_like(video_embeds))
2299
+
2300
+ if attention_mask is not None:
2301
+ attention_mask = attention_mask.to(inputs_embeds.device)
2302
+
2303
+ outputs = self.model(
2304
+ attention_mask=attention_mask,
2305
+ position_ids=position_ids_p,
2306
+ past_key_values=past_key_values,
2307
+ inputs_embeds=inputs_embeds,
2308
+ use_cache=use_cache,
2309
+ output_attentions=output_attentions,
2310
+ output_hidden_states=output_hidden_states,
2311
+ return_dict=return_dict,
2312
+ cache_position=cache_position,
2313
+ )
2314
+
2315
+ hidden_states = outputs[0]
2316
+ logits = self.lm_head(hidden_states)
2317
+
2318
+ loss = None
2319
+ if labels is not None:
2320
+ logits = logits.float()
2321
+ # Shift so that tokens < n predict n
2322
+ if attention_mask is not None:
2323
+ shift_attention_mask = attention_mask[..., 1:]
2324
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
2325
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
2326
+ else:
2327
+ shift_logits = logits[..., :-1, :].contiguous()
2328
+ shift_labels = labels[..., 1:].contiguous()
2329
+ # Flatten the tokens
2330
+ loss_fct = nn.CrossEntropyLoss()
2331
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))
2332
+
2333
+ if not return_dict:
2334
+ output = (logits,) + ((embeds_to_talker, outputs[0])) + outputs[1:]
2335
+ return (loss,) + output if loss is not None else output
2336
+
2337
+ return Qwen2_5OmniThinkerCausalLMOutputWithPast(
2338
+ loss=loss,
2339
+ logits=logits,
2340
+ past_key_values=outputs.past_key_values,
2341
+ hidden_states=(embeds_to_talker, outputs.hidden_states),
2342
+ attentions=outputs.attentions,
2343
+ attention_mask=attention_mask,
2344
+ rope_deltas=rope_deltas,
2345
+ )
2346
+
2347
+ def prepare_inputs_for_generation(
2348
+ self,
2349
+ input_ids,
2350
+ past_key_values=None,
2351
+ attention_mask=None,
2352
+ inputs_embeds=None,
2353
+ cache_position=None,
2354
+ position_ids=None,
2355
+ use_cache=True,
2356
+ pixel_values=None,
2357
+ pixel_values_videos=None,
2358
+ image_grid_thw=None,
2359
+ video_grid_thw=None,
2360
+ input_features=None,
2361
+ feature_attention_mask=None,
2362
+ use_audio_in_video=False,
2363
+ video_second_per_grid=None,
2364
+ **kwargs,
2365
+ ):
2366
+ model_inputs = super().prepare_inputs_for_generation(
2367
+ input_ids,
2368
+ past_key_values=past_key_values,
2369
+ attention_mask=attention_mask,
2370
+ inputs_embeds=inputs_embeds,
2371
+ cache_position=cache_position,
2372
+ position_ids=position_ids,
2373
+ use_cache=use_cache,
2374
+ pixel_values=pixel_values,
2375
+ pixel_values_videos=pixel_values_videos,
2376
+ image_grid_thw=image_grid_thw,
2377
+ video_grid_thw=video_grid_thw,
2378
+ input_features=input_features,
2379
+ feature_attention_mask=feature_attention_mask,
2380
+ use_audio_in_video=use_audio_in_video,
2381
+ video_second_per_grid=video_second_per_grid,
2382
+ **kwargs,
2383
+ )
2384
+
2385
+ model_inputs["position_ids"] = None
2386
+
2387
+ if cache_position[0] != 0:
2388
+ model_inputs["pixel_values"] = None
2389
+ model_inputs["pixel_values_videos"] = None
2390
+
2391
+ return model_inputs
2392
+
2393
+ def _update_model_kwargs_for_generation(
2394
+ self,
2395
+ outputs: ModelOutput,
2396
+ model_kwargs: Dict[str, Any],
2397
+ is_encoder_decoder: bool = False,
2398
+ num_new_tokens: int = 1,
2399
+ ) -> Dict[str, Any]:
2400
+ # update attention_mask
2401
+ if getattr(outputs, "attention_mask", None) is not None:
2402
+ model_kwargs["attention_mask"] = outputs.attention_mask
2403
+
2404
+ model_kwargs = super()._update_model_kwargs_for_generation(
2405
+ outputs, model_kwargs, is_encoder_decoder, num_new_tokens
2406
+ )
2407
+
2408
+ if getattr(outputs, "rope_deltas", None) is not None:
2409
+ model_kwargs["rope_deltas"] = outputs.rope_deltas
2410
+
2411
+ return model_kwargs
2412
+
2413
+
2414
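Editorial note: the language-modelling loss in `Qwen2_5OmniThinkerForConditionalGeneration.forward` above shifts logits and labels so that position t predicts token t+1, dropping padded positions through the attention mask. A small self-contained sketch of just that shifting, with toy shapes and hypothetical variable names:

import torch
import torch.nn as nn

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]])   # second sample is left-padded

# Tokens < n predict n: drop the last logit and the first label, then mask out padding.
shift_mask = attention_mask[..., 1:]
shift_logits = logits[..., :-1, :][shift_mask != 0]
shift_labels = labels[..., 1:][shift_mask != 0]

loss = nn.CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss.item())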
+ @dataclass
2415
+ class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput):
2416
+
2417
+ loss: Optional[torch.FloatTensor] = None
2418
+ logits: Optional[torch.FloatTensor] = None
2419
+ past_key_values: Optional[List[torch.FloatTensor]] = None
2420
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
2421
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
2422
+ attention_mask: Optional[torch.Tensor] = None
2423
+ rope_deltas: Optional[torch.LongTensor] = None
2424
+ thinker_reply_part: Optional[torch.Tensor] = None
2425
+
2426
+
2427
+ @add_start_docstrings(
2428
+ "The bare Qwen2.5OmniTalker Model outputting raw hidden-states without any specific head on top.",
2429
+ QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTalkerConfig"),
2430
+ )
2431
+ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
2432
+ config_class = Qwen2_5OmniTalkerConfig
2433
+ _no_split_modules = ["Qwen2_5OmniTalkerDecoderLayer"]
2434
+
2435
+ def __init__(self, config: Qwen2_5OmniTalkerConfig):
2436
+ super().__init__(config)
2437
+ self.padding_idx = config.pad_token_id
2438
+ self.vocab_size = config.vocab_size
2439
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size, self.padding_idx)
2440
+ self.layers = nn.ModuleList(
2441
+ [Qwen2_5OmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
2442
+ )
2443
+ self._attn_implementation = config._attn_implementation
2444
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
2445
+ self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
2446
+
2447
+ self.gradient_checkpointing = False
2448
+ # Initialize weights and apply final processing
2449
+ self.post_init()
2450
+
2451
+ def get_input_embeddings(self):
2452
+ return self.embed_tokens
2453
+
2454
+ def set_input_embeddings(self, value):
2455
+ self.embed_tokens = value
2456
+
2457
+ def forward(
2458
+ self,
2459
+ input_ids: Optional[torch.LongTensor] = None,
2460
+ attention_mask: Optional[torch.Tensor] = None,
2461
+ position_ids: Optional[torch.Tensor] = None,
2462
+ past_key_values: Optional[Any] = None,
2463
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2464
+ use_cache: Optional[bool] = None,
2465
+ output_attentions: Optional[bool] = None,
2466
+ output_hidden_states: Optional[bool] = None,
2467
+ return_dict: Optional[bool] = None,
2468
+ cache_position: Optional[torch.Tensor] = None,
2469
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
2470
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2471
+ output_hidden_states = (
2472
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2473
+ )
2474
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
2475
+
2476
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2477
+
2478
+ if (input_ids is None) ^ (inputs_embeds is not None):
2479
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
2480
+
2481
+ if self.gradient_checkpointing and self.training:
2482
+ if use_cache:
2483
+ logger.warning_once(
2484
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
2485
+ )
2486
+ use_cache = False
2487
+
2488
+ # torch.jit.trace() doesn't support cache objects in the output
2489
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
2490
+ past_key_values = DynamicCache()
2491
+
2492
+ if inputs_embeds is None:
2493
+ inputs_embeds = self.embed_tokens(input_ids)
2494
+
2495
+ if cache_position is None:
2496
+ past_seen_tokens: Any
2497
+ if past_key_values is not None:
2498
+ past_seen_tokens = past_key_values.get_seq_length()
2499
+ else:
2500
+ past_seen_tokens = 0
2501
+ cache_position = torch.arange(
2502
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
2503
+ )
2504
+
2505
+ # The hard-coded `3` is for the temporal, height and width dimensions.
2506
+ if position_ids is None and cache_position is not None:
2507
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
2508
+ elif position_ids.dim() == 2:
2509
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
2510
+ assert attention_mask is not None and cache_position is not None
2511
+ causal_mask = self._update_causal_mask(
2512
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
2513
+ )
2514
+
2515
+ hidden_states = inputs_embeds
2516
+
2517
+ # create position embeddings to be shared across the decoder layers
2518
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
2519
+
2520
+ # decoder layers
2521
+ all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if output_hidden_states else None
2522
+ all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if output_attentions else None
2523
+ next_decoder_cache = None
2524
+
2525
+ for decoder_layer in self.layers:
2526
+ if output_hidden_states and all_hidden_states is not None and hidden_states is not None:
2527
+ all_hidden_states += (hidden_states,)
2528
+
2529
+ if self.gradient_checkpointing and self.training:
2530
+ layer_outputs = self._gradient_checkpointing_func(
2531
+ decoder_layer.__call__,
2532
+ hidden_states,
2533
+ causal_mask,
2534
+ position_ids,
2535
+ past_key_values,
2536
+ output_attentions,
2537
+ use_cache,
2538
+ cache_position,
2539
+ position_embeddings,
2540
+ )
2541
+ else:
2542
+ layer_outputs = decoder_layer(
2543
+ hidden_states,
2544
+ attention_mask=causal_mask,
2545
+ position_ids=position_ids,
2546
+ past_key_value=past_key_values,
2547
+ output_attentions=output_attentions,
2548
+ use_cache=use_cache,
2549
+ cache_position=cache_position,
2550
+ position_embeddings=position_embeddings,
2551
+ )
2552
+
2553
+ hidden_states = layer_outputs[0]
2554
+
2555
+ if use_cache:
2556
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
2557
+
2558
+ if output_attentions and all_self_attns is not None and layer_outputs is not None:
2559
+ all_self_attns += (layer_outputs[1],)
2560
+
2561
+ hidden_states = self.norm(hidden_states)
2562
+
2563
+ # add hidden states from the last decoder layer
2564
+ if output_hidden_states and all_hidden_states is not None and hidden_states is not None:
2565
+ all_hidden_states += (hidden_states,)
2566
+
2567
+ next_cache = next_decoder_cache if use_cache else None
2568
+
2569
+ if not return_dict:
2570
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
2571
+ return BaseModelOutputWithPast(
2572
+ last_hidden_state=hidden_states,
2573
+ past_key_values=next_cache,
2574
+ hidden_states=all_hidden_states,
2575
+ attentions=all_self_attns,
2576
+ )
2577
+
2578
+ def _update_causal_mask(
2579
+ self,
2580
+ attention_mask: torch.Tensor,
2581
+ input_tensor: torch.Tensor,
2582
+ cache_position: torch.Tensor,
2583
+ past_key_values: Cache,
2584
+ output_attentions: bool,
2585
+ ):
2586
+ if self.config._attn_implementation == "flash_attention_2":
2587
+ if attention_mask is not None and past_key_values is not None:
2588
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
2589
+ if is_padding_right:
2590
+ raise ValueError(
2591
+ "You are attempting to perform batched generation with padding_side='right';"
2592
+ " this may lead to unexpected behaviour for the Flash Attention version "
2593
+ "of Qwen25OmniTalker. Make sure to "
2594
+ "call `tokenizer.padding_side = 'left'` before tokenizing the input."
2595
+ )
2596
+ if attention_mask is not None and 0.0 in attention_mask:
2597
+ return attention_mask
2598
+ return None
2599
+
2600
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
2601
+ using_static_cache = isinstance(past_key_values, StaticCache)
2602
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
2603
+
2604
+ if (
2605
+ self.config._attn_implementation == "sdpa"
2606
+ and not (using_static_cache or using_sliding_window_cache)
2607
+ and not output_attentions
2608
+ ):
2609
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
2610
+ attention_mask,
2611
+ inputs_embeds=input_tensor,
2612
+ past_key_values_length=past_seen_tokens,
2613
+ sliding_window=self.config.sliding_window,
2614
+ is_training=self.training,
2615
+ ):
2616
+ return None
2617
+
2618
+ dtype, device = input_tensor.dtype, input_tensor.device
2619
+ min_dtype = torch.finfo(dtype).min
2620
+ sequence_length = input_tensor.shape[1]
2621
+ # SlidingWindowCache or StaticCache
2622
+ if using_sliding_window_cache or using_static_cache:
2623
+ target_length = past_key_values.get_max_cache_shape()
2624
+ # DynamicCache or no cache
2625
+ else:
2626
+ target_length = (
2627
+ attention_mask.shape[-1]
2628
+ if isinstance(attention_mask, torch.Tensor)
2629
+ else past_seen_tokens + sequence_length + 1
2630
+ )
2631
+
2632
+ # In case the provided `attention_mask` is 2-D, we generate a 4-D causal mask here.
2633
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
2634
+ attention_mask,
2635
+ sequence_length=sequence_length,
2636
+ target_length=target_length,
2637
+ dtype=dtype,
2638
+ device=device,
2639
+ cache_position=cache_position,
2640
+ batch_size=input_tensor.shape[0],
2641
+ config=self.config,
2642
+ past_key_values=past_key_values,
2643
+ )
2644
+
2645
+ if (
2646
+ self.config._attn_implementation == "sdpa"
2647
+ and attention_mask is not None
2648
+ and attention_mask.device.type in ["cuda", "xpu"]
2649
+ and not output_attentions
2650
+ ):
2651
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
2652
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
2653
+ # Details: https://github.com/pytorch/pytorch/issues/110213
2654
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
2655
+
2656
+ return causal_mask
2657
+
2658
+ @staticmethod
2659
+ def _prepare_4d_causal_attention_mask_with_cache_position(
2660
+ attention_mask: torch.Tensor,
2661
+ sequence_length: int,
2662
+ target_length: int,
2663
+ dtype: torch.dtype,
2664
+ device: torch.device,
2665
+ cache_position: torch.Tensor,
2666
+ batch_size: int,
2667
+ config: Qwen2_5OmniConfig,
2668
+ past_key_values: Cache,
2669
+ ):
2670
+ if attention_mask is not None and attention_mask.dim() == 4:
2671
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
2672
+ causal_mask = attention_mask
2673
+ else:
2674
+ min_dtype = torch.finfo(dtype).min
2675
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
2676
+ diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
2677
+ if config.sliding_window is not None:
2678
+ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
2679
+ sliding_attend_mask = torch.arange(target_length, device=device) <= (
2680
+ cache_position.reshape(-1, 1) - config.sliding_window
2681
+ )
2682
+ diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
2683
+ causal_mask *= diagonal_attend_mask
2684
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
2685
+ if attention_mask is not None:
2686
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
2687
+ if attention_mask.shape[-1] > target_length:
2688
+ attention_mask = attention_mask[:, :target_length]
2689
+ mask_length = attention_mask.shape[-1]
2690
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
2691
+ causal_mask.device
2692
+ )
2693
+ padding_mask = padding_mask == 0
2694
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
2695
+ padding_mask, min_dtype
2696
+ )
2697
+ return causal_mask
2698
+
2699
+
2700
+ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
2701
+ config_class = Qwen2_5OmniTalkerConfig
2702
+
2703
+ def __init__(self, config: Qwen2_5OmniTalkerConfig):
2704
+ super().__init__(config)
2705
+
2706
+ self.thinker_to_talker_proj = nn.Linear(config.embedding_size, config.hidden_size)
2707
+
2708
+ self.model = Qwen2_5OmniTalkerModel(config)
2709
+ self.codebook_size = config.vocab_size
2710
+ self.codec_head = nn.Linear(config.hidden_size, self.codebook_size, bias=False)
2711
+
2712
+ self.codec_bos_token = config.tts_codec_start_token_id
2713
+ self.codec_eos_token = config.tts_codec_end_token_id
2714
+ self.codec_pad_token = config.tts_codec_pad_token_id
2715
+ self.codec_mask_token = config.tts_codec_mask_token_id
2716
+
2717
+ self.text_bos_token = config.tts_text_start_token_id
2718
+ self.text_eos_token = config.tts_text_end_token_id
2719
+ self.text_pad_token = config.tts_text_pad_token_id
2720
+
2721
+ self.spatial_merge_size = self.config.spatial_merge_size
2722
+
2723
+ self.post_init()
2724
+
2725
+ def forward(
2726
+ self,
2727
+ input_ids: torch.LongTensor,
2728
+ attention_mask: Optional[torch.Tensor] = None,
2729
+ position_ids: Optional[torch.Tensor] = None,
2730
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
2731
+ thinker_reply_part: Optional[torch.Tensor] = None,
2732
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2733
+ rope_deltas: Optional[torch.LongTensor] = None,
2734
+ use_cache: Optional[bool] = None,
2735
+ cache_position: Optional[torch.LongTensor] = None,
2736
+ input_text_ids: Optional[torch.LongTensor] = None,
2737
+ image_grid_thw: Optional[torch.LongTensor] = None,
2738
+ video_grid_thw: Optional[torch.LongTensor] = None,
2739
+ use_audio_in_video: Optional[bool] = None,
2740
+ audio_feature_lengths: Optional[torch.LongTensor] = None,
2741
+ video_second_per_grid: Optional[torch.LongTensor] = None,
2742
+ output_attentions: Optional[bool] = None,
2743
+ output_hidden_states: Optional[bool] = None,
2744
+ return_dict: Optional[bool] = None,
2745
+ ) -> Union[Tuple, Qwen2_5OmniTalkerCausalLMOutputWithPast]:
2746
+
2747
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2748
+ output_hidden_states = (
2749
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2750
+ )
2751
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2752
+
2753
+ if attention_mask is not None and position_ids is None:
2754
+ if cache_position is None or (cache_position is not None and cache_position[0] == 0):
2755
+ position_ids, rope_deltas = self.get_rope_index(
2756
+ input_text_ids,
2757
+ image_grid_thw,
2758
+ video_grid_thw,
2759
+ attention_mask,
2760
+ use_audio_in_video,
2761
+ audio_feature_lengths,
2762
+ video_second_per_grid,
2763
+ )
2764
+ assert inputs_embeds is not None
2765
+ inputs_embeds[:, -1, :] += self.get_input_embeddings()(
2766
+ torch.tensor([self.codec_bos_token], dtype=torch.long, device=inputs_embeds.device)
2767
+ )
2768
+ inputs_embeds[:, -2, :] += self.get_input_embeddings()(
2769
+ torch.tensor([self.codec_pad_token], dtype=torch.long, device=inputs_embeds.device)
2770
+ )
2771
+
2772
+ else:
2773
+ assert input_ids is not None
2774
+ batch_size, seq_length = input_ids.shape
2775
+ delta = (
2776
+ cache_position[0] + rope_deltas
2777
+ if cache_position is not None and rope_deltas is not None
2778
+ else torch.tensor(0, device=input_ids.device)
2779
+ )
2780
+ position_ids = torch.arange(seq_length, device=input_ids.device)
2781
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
2782
+ position_ids = position_ids.add(delta)
2783
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
2784
+
2785
+ if inputs_embeds is None:
2786
+ assert thinker_reply_part is not None
2787
+ # 1. Decode the second and subsequent tokens
2788
+ codec_embeds = self.get_input_embeddings()(input_ids)
2789
+ inputs_embeds = codec_embeds + thinker_reply_part[:, :1, :]
2790
+ if thinker_reply_part.shape[1] > 1:
2791
+ thinker_reply_part = thinker_reply_part[:, 1:, :]
2792
+
2793
+ talker_lm_input = self.thinker_to_talker_proj(inputs_embeds)
2794
+
2795
+ if attention_mask is not None:
2796
+ attention_mask = attention_mask.to(inputs_embeds.device)
2797
+
2798
+ outputs = self.model(
2799
+ attention_mask=attention_mask,
2800
+ position_ids=position_ids,
2801
+ past_key_values=past_key_values,
2802
+ inputs_embeds=talker_lm_input,
2803
+ use_cache=use_cache,
2804
+ output_attentions=output_attentions,
2805
+ output_hidden_states=output_hidden_states,
2806
+ return_dict=return_dict,
2807
+ )
2808
+
2809
+ hidden_states = outputs[0]
2810
+ logits = self.codec_head(hidden_states)
2811
+ logits = logits.float()
2812
+
2813
+ loss = None
2814
+
2815
+ if not return_dict:
2816
+ output = (logits,) + outputs[1:]
2817
+ return (loss,) + output if loss is not None else output
2818
+
2819
+ return Qwen2_5OmniTalkerCausalLMOutputWithPast(
2820
+ loss=loss,
2821
+ logits=logits,
2822
+ past_key_values=outputs.past_key_values,
2823
+ hidden_states=hidden_states,
2824
+ attentions=outputs.attentions,
2825
+ attention_mask=attention_mask,
2826
+ rope_deltas=rope_deltas,
2827
+ thinker_reply_part=thinker_reply_part,
2828
+ )
2829
+
2830
+ def _get_initial_cache_position(self, input_ids, model_kwargs):
2831
+ # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily
2832
+ inputs_embeds = model_kwargs.pop("inputs_embeds")
2833
+ model_kwargs = super()._get_initial_cache_position(input_ids, model_kwargs)
2834
+ model_kwargs["inputs_embeds"] = inputs_embeds
2835
+ return model_kwargs
2836
+
2837
+ # prepare inputs for talker lm generation
2838
+ def prepare_inputs_for_generation(
2839
+ self,
2840
+ input_ids,
2841
+ input_text_ids,
2842
+ past_key_values=None,
2843
+ attention_mask=None,
2844
+ inputs_embeds=None,
2845
+ thinker_reply_part=None,
2846
+ cache_position=None,
2847
+ position_ids=None,
2848
+ use_cache=True,
2849
+ pixel_values=None,
2850
+ pixel_values_videos=None,
2851
+ image_grid_thw=None,
2852
+ video_grid_thw=None,
2853
+ input_audio_features=None,
2854
+ audio_feature_attention_mask=None,
2855
+ audio_feature_lengths=None,
2856
+ use_audio_in_video=False,
2857
+ video_second_per_grid=None,
2858
+ **kwargs,
2859
+ ):
2860
+ model_inputs = super().prepare_inputs_for_generation(
2861
+ input_ids,
2862
+ past_key_values,
2863
+ attention_mask,
2864
+ inputs_embeds,
2865
+ cache_position,
2866
+ use_cache=use_cache,
2867
+ thinker_reply_part=thinker_reply_part,
2868
+ input_text_ids=input_text_ids,
2869
+ image_grid_thw=image_grid_thw,
2870
+ video_grid_thw=video_grid_thw,
2871
+ use_audio_in_video=use_audio_in_video,
2872
+ audio_feature_lengths=audio_feature_lengths,
2873
+ video_second_per_grid=video_second_per_grid,
2874
+ **kwargs,
2875
+ )
2876
+
2877
+ model_inputs["position_ids"] = None
2878
+
2879
+ return model_inputs
2880
+
2881
+ def _update_model_kwargs_for_generation(
2882
+ self,
2883
+ outputs: ModelOutput,
2884
+ model_kwargs: Dict[str, Any],
2885
+ is_encoder_decoder: bool = False,
2886
+ num_new_tokens: int = 1,
2887
+ ) -> Dict[str, Any]:
2888
+ # update attention_mask
2889
+ if getattr(outputs, "attention_mask", None) is not None:
2890
+ model_kwargs["attention_mask"] = outputs.attention_mask
2891
+
2892
+ model_kwargs = super()._update_model_kwargs_for_generation(
2893
+ outputs, model_kwargs, is_encoder_decoder, num_new_tokens
2894
+ )
2895
+
2896
+ if getattr(outputs, "rope_deltas", None) is not None:
2897
+ model_kwargs["rope_deltas"] = outputs.rope_deltas
2898
+
2899
+ if getattr(outputs, "thinker_reply_part", None) is not None:
2900
+ model_kwargs["thinker_reply_part"] = outputs.thinker_reply_part
2901
+
2902
+ return model_kwargs
2903
+
2904
+
2905
+ # Uses a custom RoPE; will switch to LlamaRotaryEmbedding in a future version
2906
+ class RotaryEmbedding(nn.Module):
2907
+ def __init__(self, dim, base=10000):
2908
+ super().__init__()
2909
+
2910
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
2911
+ self.register_buffer("inv_freq", inv_freq)
2912
+
2913
+ def forward(self, x):
2914
+ batch_size, seq_len = x.shape[0], x.shape[1]
2915
+ t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
2916
+ freqs = torch.einsum("i , j -> i j", t.type_as(self.inv_freq), self.inv_freq)
2917
+ freqs = torch.stack((freqs, freqs), dim=-1)
2918
+ freqs = freqs.reshape(*freqs.shape[:-2], -1)
2919
+ freqs = freqs.repeat(batch_size, *([1] * freqs.dim()))
2920
+
2921
+ return freqs.cos(), freqs.sin()
2922
+
2923
+
2924
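Editorial note: a hypothetical usage sketch for the custom `RotaryEmbedding` defined above (it assumes the class from this file is in scope). As written, the module duplicates each inverse frequency along the last dimension and returns per-batch cos/sin tables:

import torch

rope = RotaryEmbedding(dim=8)      # dim is the (even) head dimension
x = torch.randn(2, 5, 8)           # [batch, seq_len, head_dim]
cos, sin = rope(x)
print(cos.shape, sin.shape)        # torch.Size([2, 5, 8]) torch.Size([2, 5, 8])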
+ class TDNNBlock(nn.Module):
2925
+ def __init__(
2926
+ self,
2927
+ in_channels,
2928
+ out_channels,
2929
+ kernel_size,
2930
+ dilation,
2931
+ ):
2932
+ super().__init__()
2933
+ self.conv = nn.Conv1d(
2934
+ in_channels=in_channels,
2935
+ out_channels=out_channels,
2936
+ kernel_size=kernel_size,
2937
+ dilation=dilation,
2938
+ padding="same",
2939
+ padding_mode="reflect",
2940
+ )
2941
+ self.activation = nn.ReLU()
2942
+
2943
+ def forward(self, x):
2944
+ return self.activation(self.conv(x))
2945
+
2946
+
2947
+ class Res2NetBlock(torch.nn.Module):
2948
+ """An implementation of Res2NetBlock w/ dilation.
2949
+
2950
+ Arguments
2951
+ ---------
2952
+ in_channels : int
2953
+ The number of channels expected in the input.
2954
+ out_channels : int
2955
+ The number of output channels.
2956
+ scale : int
2957
+ The scale of the Res2Net block.
2958
+ kernel_size: int
2959
+ The kernel size of the Res2Net block.
2960
+ dilation : int
2961
+ The dilation of the Res2Net block.
2962
+ """
2963
+
2964
+ def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
2965
+ super().__init__()
2966
+ assert in_channels % scale == 0
2967
+ assert out_channels % scale == 0
2968
+
2969
+ in_channel = in_channels // scale
2970
+ hidden_channel = out_channels // scale
2971
+
2972
+ self.blocks = nn.ModuleList(
2973
+ [
2974
+ TDNNBlock(
2975
+ in_channel,
2976
+ hidden_channel,
2977
+ kernel_size=kernel_size,
2978
+ dilation=dilation,
2979
+ )
2980
+ for i in range(scale - 1)
2981
+ ]
2982
+ )
2983
+ self.scale = scale
2984
+
2985
+ def forward(self, x):
2986
+ y = []
2987
+ for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
2988
+ if i == 0:
2989
+ y_i = x_i
2990
+ elif i == 1:
2991
+ y_i = self.blocks[i - 1](x_i)
2992
+ else:
2993
+ y_i = self.blocks[i - 1](x_i + y_i)
2994
+ y.append(y_i)
2995
+ y_p = torch.cat(y, dim=1)
2996
+ return y_p
2997
+
2998
+
2999
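Editorial note: a short, hypothetical usage sketch for `Res2NetBlock` above (shapes are made up; the class defined in this file is assumed to be in scope). The channels are chunked into `scale` groups, and every group after the first flows through a chain of TDNN blocks before being re-concatenated:

import torch

block = Res2NetBlock(in_channels=64, out_channels=64, scale=8, kernel_size=3, dilation=1)
x = torch.randn(2, 64, 100)        # [batch, channels, time]
print(block(x).shape)              # torch.Size([2, 64, 100])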
+ class SEBlock(nn.Module):
3000
+ """An implementation of squeeze-and-excitation block.
3001
+
3002
+ Arguments
3003
+ ---------
3004
+ in_channels : int
3005
+ The number of input channels.
3006
+ se_channels : int
3007
+ The number of output channels after squeeze.
3008
+ out_channels : int
3009
+ The number of output channels.
3010
+ """
3011
+
3012
+ def __init__(self, in_channels, se_channels, out_channels):
3013
+ super().__init__()
3014
+
3015
+ self.conv1 = nn.Conv1d(
3016
+ in_channels=in_channels,
3017
+ out_channels=se_channels,
3018
+ kernel_size=1,
3019
+ padding="same",
3020
+ padding_mode="reflect",
3021
+ )
3022
+ self.relu = nn.ReLU(inplace=True)
3023
+ self.conv2 = nn.Conv1d(
3024
+ in_channels=se_channels,
3025
+ out_channels=out_channels,
3026
+ kernel_size=1,
3027
+ padding="same",
3028
+ padding_mode="reflect",
3029
+ )
3030
+ self.sigmoid = nn.Sigmoid()
3031
+
3032
+ def forward(self, x):
3033
+ s = x.mean(dim=2, keepdim=True)
3034
+
3035
+ s = self.relu(self.conv1(s))
3036
+ s = self.sigmoid(self.conv2(s))
3037
+
3038
+ return s * x
3039
+
3040
+
3041
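Editorial note: the same kind of hypothetical sketch for `SEBlock` above: the time axis is mean-pooled, squeezed through `se_channels`, and the resulting sigmoid gates rescale each channel of the input.

import torch

se = SEBlock(in_channels=64, se_channels=16, out_channels=64)
x = torch.randn(2, 64, 100)        # [batch, channels, time]
print(se(x).shape)                 # torch.Size([2, 64, 100]); channel-wise gating only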
+ class AttentiveStatisticsPooling(nn.Module):
3042
+ """This class implements an attentive statistic pooling layer for each channel.
3043
+ It returns the concatenated mean and std of the input tensor.
3044
+
3045
+ Arguments
3046
+ ---------
3047
+ channels: int
3048
+ The number of input channels.
3049
+ attention_channels: int
3050
+ The number of attention channels.
3051
+ """
3052
+
3053
+ def __init__(self, channels, attention_channels=128):
3054
+ super().__init__()
3055
+
3056
+ self.eps = 1e-12
3057
+ self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
3058
+ self.tanh = nn.Tanh()
3059
+ self.conv = nn.Conv1d(
3060
+ in_channels=attention_channels,
3061
+ out_channels=channels,
3062
+ kernel_size=1,
3063
+ padding="same",
3064
+ padding_mode="reflect",
3065
+ )
3066
+
3067
+ def _length_to_mask(self, length, max_len=None, dtype=None, device=None):
3068
+ assert len(length.shape) == 1
3069
+
3070
+ if max_len is None:
3071
+ max_len = length.max().long().item() # using arange to generate mask
3072
+ mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
3073
+ len(length), max_len
3074
+ ) < length.unsqueeze(1)
3075
+
3076
+ mask = torch.as_tensor(mask, dtype=dtype, device=device)
3077
+ return mask
3078
+
3079
+ def _compute_statistics(self, x, m, dim=2):
3080
+ mean = (m * x).sum(dim)
3081
+ std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps))
3082
+ return mean, std
3083
+
3084
+ def forward(self, x):
3085
+ """Calculates mean and std for a batch (input tensor).
3086
+
3087
+ Arguments
3088
+ ---------
3089
+ x : torch.Tensor
3090
+ Tensor of shape [N, C, L].
3091
+ """
3092
+ L = x.shape[-1]
3093
+
3094
+ lengths = torch.ones(x.shape[0], device=x.device)
3095
+
3096
+ # Make binary mask of shape [N, 1, L]
3097
+ mask = self._length_to_mask(lengths * L, max_len=L, dtype=x.dtype, device=x.device)
3098
+ mask = mask.unsqueeze(1)
3099
+
3100
+ # Expand the temporal context of the pooling layer by allowing the
3101
+ # self-attention to look at global properties of the utterance.
3102
+ total = mask.sum(dim=2, keepdim=True)
3103
+
3104
+ mean, std = self._compute_statistics(x, mask / total)
3105
+ mean = mean.unsqueeze(2).repeat(1, 1, L)
3106
+ std = std.unsqueeze(2).repeat(1, 1, L)
3107
+ attn = torch.cat([x, mean, std], dim=1)
3108
+
3109
+ # Apply layers
3110
+ attn = self.conv(self.tanh(self.tdnn(attn)))
3111
+
3112
+ # Filter out zero-paddings
3113
+ attn = attn.masked_fill(mask == 0, float("-inf"))
3114
+
3115
+ attn = F.softmax(attn, dim=2)
3116
+ mean, std = self._compute_statistics(x, attn)
3117
+ # Concatenate the attention-weighted mean and std
3118
+ pooled_stats = torch.cat((mean, std), dim=1)
3119
+ pooled_stats = pooled_stats.unsqueeze(2)
3120
+
3121
+ return pooled_stats
3122
+
3123
+
3124
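Editorial note: a shape-level sketch of `AttentiveStatisticsPooling` above (hypothetical sizes, class assumed in scope): an [N, C, L] feature map is reduced to an [N, 2C, 1] vector of attention-weighted channel means and standard deviations.

import torch

asp = AttentiveStatisticsPooling(channels=64, attention_channels=16)
x = torch.randn(2, 64, 100)        # [N, C, L]
print(asp(x).shape)                # torch.Size([2, 128, 1]): mean and std concatenated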
+ class SERes2NetBlock(nn.Module):
3125
+ """An implementation of building block in ECAPA-TDNN, i.e.,
3126
+ TDNN-Res2Net-TDNN-SEBlock.
3127
+
3128
+ Arguments
3129
+ ---------
3130
+ out_channels: int
3131
+ The number of output channels.
3132
+ res2net_scale: int
3133
+ The scale of the Res2Net block.
3134
+ kernel_size: int
3135
+ The kernel size of the TDNN blocks.
3136
+ dilation: int
3137
+ The dilation of the Res2Net block.
3138
+ activation : torch class
3139
+ A class for constructing the activation layers.
3140
+ """
3141
+
3142
+ def __init__(
3143
+ self,
3144
+ in_channels,
3145
+ out_channels,
3146
+ res2net_scale=8,
3147
+ se_channels=128,
3148
+ kernel_size=1,
3149
+ dilation=1,
3150
+ ):
3151
+ super().__init__()
3152
+ self.out_channels = out_channels
3153
+ self.tdnn1 = TDNNBlock(
3154
+ in_channels,
3155
+ out_channels,
3156
+ kernel_size=1,
3157
+ dilation=1,
3158
+ )
3159
+ self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
3160
+ self.tdnn2 = TDNNBlock(
3161
+ out_channels,
3162
+ out_channels,
3163
+ kernel_size=1,
3164
+ dilation=1,
3165
+ )
3166
+ self.se_block = SEBlock(out_channels, se_channels, out_channels)
3167
+
3168
+ def forward(self, x):
3169
+ residual = x
3170
+
3171
+ x = self.tdnn1(x)
3172
+ x = self.res2net_block(x)
3173
+ x = self.tdnn2(x)
3174
+ x = self.se_block(x)
3175
+
3176
+ return x + residual
3177
+
3178
+
3179
+ class ECAPA_TDNN(torch.nn.Module):
3180
+
3181
+ def __init__(self, config: Qwen2_5OmniDiTConfig):
3182
+ super().__init__()
3183
+ assert len(config.enc_channels) == len(config.enc_kernel_sizes)
3184
+ assert len(config.enc_channels) == len(config.enc_dilations)
3185
+ self.channels = config.enc_channels
3186
+ self.blocks = nn.ModuleList()
3187
+
3188
+ # The initial TDNN layer
3189
+ self.blocks.append(
3190
+ TDNNBlock(
3191
+ config.mel_dim,
3192
+ config.enc_channels[0],
3193
+ config.enc_kernel_sizes[0],
3194
+ config.enc_dilations[0],
3195
+ )
3196
+ )
3197
+
3198
+ # SE-Res2Net layers
3199
+ for i in range(1, len(config.enc_channels) - 1):
3200
+ self.blocks.append(
3201
+ SERes2NetBlock(
3202
+ config.enc_channels[i - 1],
3203
+ config.enc_channels[i],
3204
+ res2net_scale=config.enc_res2net_scale,
3205
+ se_channels=config.enc_se_channels,
3206
+ kernel_size=config.enc_kernel_sizes[i],
3207
+ dilation=config.enc_dilations[i],
3208
+ )
3209
+ )
3210
+
3211
+ # Multi-layer feature aggregation
3212
+ self.mfa = TDNNBlock(
3213
+ config.enc_channels[-1],
3214
+ config.enc_channels[-1],
3215
+ config.enc_kernel_sizes[-1],
3216
+ config.enc_dilations[-1],
3217
+ )
3218
+
3219
+ # Attentive Statistical Pooling
3220
+ self.asp = AttentiveStatisticsPooling(
3221
+ config.enc_channels[-1],
3222
+ attention_channels=config.enc_attention_channels,
3223
+ )
3224
+
3225
+ # Final linear transformation
3226
+ self.fc = nn.Conv1d(
3227
+ in_channels=config.enc_channels[-1] * 2,
3228
+ out_channels=config.enc_dim,
3229
+ kernel_size=1,
3230
+ padding="same",
3231
+ padding_mode="reflect",
3232
+ )
3233
+
3234
+ def forward(self, x):
3235
+ """Returns the embedding vector.
3236
+
3237
+ Arguments
3238
+ ---------
3239
+ x : torch.Tensor
3240
+ Tensor of shape (batch, time, channel).
3241
+ """
3242
+ # Minimize transpose for efficiency
3243
+ x = x.transpose(1, 2)
3244
+
3245
+ xl = []
3246
+ for layer in self.blocks:
3247
+ x = layer(x)
3248
+ xl.append(x)
3249
+
3250
+ # Multi-layer feature aggregation
3251
+ x = torch.cat(xl[1:], dim=1)
3252
+ x = self.mfa(x)
3253
+
3254
+ # Attentive Statistical Pooling
3255
+ x = self.asp(x)
3256
+
3257
+ # Final linear transformation
3258
+ x = self.fc(x)
3259
+
3260
+ x = x.squeeze(-1)
3261
+ return x
3262
+
3263
+
3264
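Editorial note: `ECAPA_TDNN` above turns a variable-length mel spectrogram into a fixed-size speaker embedding. The sketch below instantiates it with made-up configuration values through a `SimpleNamespace`; these values are illustrative only and are not the real `Qwen2_5OmniDiTConfig` defaults.

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(
    mel_dim=80, enc_dim=128,
    enc_channels=[256, 256, 256, 256, 768],
    enc_kernel_sizes=[5, 3, 3, 3, 1],
    enc_dilations=[1, 2, 3, 4, 1],
    enc_res2net_scale=4, enc_se_channels=64, enc_attention_channels=64,
)
encoder = ECAPA_TDNN(cfg)
mel = torch.randn(2, 200, 80)      # [batch, time, mel_dim]
print(encoder(mel).shape)          # torch.Size([2, 128]): one embedding per utterance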
+ class InputEmbedding(nn.Module):
3265
+ def __init__(self, config: Qwen2_5OmniDiTConfig):
3266
+ super().__init__()
3267
+ self.proj = nn.Linear(
3268
+ config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim,
3269
+ config.hidden_size,
3270
+ )
3271
+ self.spk_encoder = ECAPA_TDNN(config)
3272
+
3273
+ def forward(self, x, spk, cond, code_embed, drop_audio_cond=False, code_embed_uncond=None, cfg=True):
3274
+ if cfg:
3275
+ x = torch.cat([x, x], dim=0)
3276
+ spk = torch.cat([spk, torch.zeros_like(spk)], dim=0)
3277
+ cond = torch.cat([cond, torch.zeros_like(cond)], dim=0)
3278
+ code_embed = torch.cat([code_embed, code_embed_uncond], dim=0)
3279
+ elif drop_audio_cond: # cfg for cond audio
3280
+ cond = torch.zeros_like(cond)
3281
+ spk = torch.zeros_like(spk)
3282
+ cond = self.spk_encoder(cond).unsqueeze(1).repeat(1, x.size(1), 1)
3283
+ x = self.proj(torch.cat((x, cond, code_embed, spk), dim=-1))
3284
+
3285
+ return x
3286
+
3287
+
3288
+ # Transformer backbone using DiT blocks
3289
+ class CodecEmbedding(nn.Module):
3290
+ def __init__(self, codec_num_embeds, codec_dim, repeats):
3291
+ super().__init__()
3292
+ self.repeats = repeats
3293
+ self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim)
3294
+
3295
+ def forward(self, code, drop_code=False):
3296
+ if drop_code:
3297
+ code = torch.zeros_like(code)
3298
+ code_embed = self.codec_embed(code)
3299
+
3300
+ code_embed = torch.repeat_interleave(code_embed, repeats=self.repeats, dim=1)
3301
+ return code_embed
3302
+
3303
+
3304
+ # AdaLayerNormZero
3305
+ # Returns the modulated x for the attention input, plus parameters for the later MLP modulation
3306
+ class AdaLayerNormZero(nn.Module):
3307
+ def __init__(self, dim):
3308
+ super().__init__()
3309
+
3310
+ self.silu = nn.SiLU()
3311
+ self.linear = nn.Linear(dim, dim * 6)
3312
+
3313
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
3314
+
3315
+ def forward(self, x, emb=None):
3316
+ emb = self.linear(self.silu(emb))
3317
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
3318
+
3319
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
3320
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
3321
+
3322
+
3323
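Editorial note: the adaLN-Zero modules above follow the DiT recipe: a conditioning embedding is projected to shift/scale/gate terms that modulate a parameter-free LayerNorm. A minimal sketch of that modulation with illustrative names (this is not the vendored module itself):

import torch
import torch.nn as nn

dim = 16
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
to_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))

x = torch.randn(2, 10, dim)        # [batch, tokens, dim]
t = torch.randn(2, dim)            # conditioning (e.g. timestep) embedding

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = to_mod(t).chunk(6, dim=1)
x_mod = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
print(x_mod.shape, gate_msa.shape)  # torch.Size([2, 10, 16]) torch.Size([2, 16])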
+ # AdaLayerNormZero for final layer
3324
+ # Returns only the modulated x for the attention input, since there is no further MLP modulation
3325
+ class AdaLayerNormZero_Final(nn.Module):
3326
+ def __init__(self, dim):
3327
+ super().__init__()
3328
+
3329
+ self.silu = nn.SiLU()
3330
+ self.linear = nn.Linear(dim, dim * 2)
3331
+
3332
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
3333
+
3334
+ def forward(self, x, emb):
3335
+ emb = self.linear(self.silu(emb))
3336
+ scale, shift = torch.chunk(emb, 2, dim=1)
3337
+
3338
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
3339
+ return x
3340
+
3341
+
3342
+ # FeedForward
3343
+ class FeedForward(nn.Module):
3344
+ def __init__(self, dim, mult=4, dropout=0.0):
3345
+ super().__init__()
3346
+ inner_dim = int(dim * mult)
3347
+
3348
+ self.ff = nn.ModuleList(
3349
+ [
3350
+ nn.Linear(dim, inner_dim),
3351
+ nn.GELU(approximate="tanh"),
3352
+ nn.Dropout(dropout),
3353
+ nn.Linear(inner_dim, dim),
3354
+ ]
3355
+ )
3356
+
3357
+ def forward(self, x):
3358
+ for layer in self.ff:
3359
+ x = layer(x)
3360
+ return x
3361
+
3362
+
3363
+ # Modified from Llama with a different rotate function; will be fixed in a future release
3364
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
3365
+
3366
+ def rotate_half_codec(x):
3367
+ # x = rearrange(x, "... (d r) -> ... d r", r=2)
3368
+ x = x.reshape(*x.shape[:-1], -1, 2)
3369
+ x1, x2 = x.unbind(dim=-1)
3370
+ x = torch.stack((-x2, x1), dim=-1)
3371
+ return x.reshape(*x.shape[:-2], -1)
3372
+
3373
+ cos = cos.unsqueeze(unsqueeze_dim)
3374
+ sin = sin.unsqueeze(unsqueeze_dim)
3375
+ q_embed = (q * cos) + (rotate_half_codec(q) * sin)
3376
+ k_embed = (k * cos) + (rotate_half_codec(k) * sin)
3377
+ return q_embed, k_embed
3378
+
3379
+
3380
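Editorial note: `apply_rotary_pos_emb` above uses an interleaved rotate-half, mapping each adjacent feature pair (x1, x2) to (-x2, x1) before the cos/sin blend. A tiny check of that pairing on a toy tensor (the function above is assumed to be in scope; setting cos=0 and sin=1 isolates the rotation):

import torch

q = torch.arange(8, dtype=torch.float32).view(1, 1, 1, 8)   # [batch, heads, seq, head_dim]
k = q.clone()
cos = torch.zeros(1, 1, 8)    # cos=0, sin=1 -> output is exactly the rotated half
sin = torch.ones(1, 1, 8)

q_rot, _ = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.view(-1))   # tensor([-1., 0., -3., 2., -5., 4., -7., 6.])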
+ class DiTAttention(nn.Module):
3381
+ def __init__(self, config: Qwen2_5OmniDiTConfig):
3382
+ super().__init__()
3383
+
3384
+ self.config = config
3385
+ self.dim = config.hidden_size
3386
+ self.heads = config.num_attention_heads
3387
+ self.inner_dim = config.head_dim * config.num_attention_heads
3388
+ self.dropout = config.dropout
3389
+ self._attn_implementation = config._attn_implementation
3390
+ self.is_causal = False
3391
+
3392
+ self.to_q = nn.Linear(config.hidden_size, self.inner_dim)
3393
+ self.to_k = nn.Linear(config.hidden_size, self.inner_dim)
3394
+ self.to_v = nn.Linear(config.hidden_size, self.inner_dim)
3395
+
3396
+ self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)])
3397
+
3398
+ def forward(
3399
+ self,
3400
+ x, # noised input x
3401
+ rope=None, # rotary position embedding for x
3402
+ mask=None,
3403
+ ) -> torch.Tensor:
3404
+ batch_size = x.shape[0]
3405
+
3406
+ # `sample` projections.
3407
+ query = self.to_q(x)
3408
+ key = self.to_k(x)
3409
+ value = self.to_v(x)
3410
+
3411
+ # attention
3412
+ inner_dim = key.shape[-1]
3413
+ head_dim = inner_dim // self.heads
3414
+ query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
3415
+ key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
3416
+ value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
3417
+
3418
+ # apply rotary position embedding
3419
+ # Due to the training process, RoPE is applied only to the first head; this will be fixed in a future release
3420
+ cos, sin = rope
3421
+ query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin)
3422
+
3423
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation]
3424
+ x, _ = attention_interface(
3425
+ self,
3426
+ query,
3427
+ key,
3428
+ value,
3429
+ attention_mask=mask,
3430
+ is_causal=False,
3431
+ )
3432
+
3433
+ # mask: e.g. when inference receives a batch with different target durations, the padding is masked out
3434
+ # x = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=0.0, is_causal=False)
3435
+ x = x.reshape(batch_size, -1, self.heads * head_dim)
3436
+ x = x.to(query.dtype)
3437
+
3438
+ # linear proj
3439
+ x = self.to_out[0](x)
3440
+ # dropout
3441
+ x = self.to_out[1](x)
3442
+
3443
+ return x
3444
+
3445
+
3446
+ # time step conditioning embedding
3447
+ class SinusPositionEmbedding(nn.Module):
3448
+ def __init__(self, dim):
3449
+ super().__init__()
3450
+ self.dim = dim
3451
+
3452
+ def forward(self, x, scale=1000):
3453
+ device = x.device
3454
+ half_dim = self.dim // 2
3455
+ emb = math.log(10000) / (half_dim - 1)
3456
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
3457
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
3458
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
3459
+ return emb.type_as(x)
3460
+
3461
+
3462
+ class TimestepEmbedding(nn.Module):
3463
+ def __init__(self, dim, freq_embed_dim=256):
3464
+ super().__init__()
3465
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
3466
+ self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)])
3467
+
3468
+ def forward(self, timestep): # noqa: F821
3469
+ time_hidden = self.time_embed(timestep)
3470
+ time_hidden = time_hidden.to(timestep.dtype)
3471
+ for layer in self.time_mlp:
3472
+ time_hidden = layer(time_hidden) # b d
3473
+ return time_hidden
3474
+
3475
+
3476
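Editorial note: a brief usage sketch for the sinusoidal timestep embedding above (hypothetical sizes, classes assumed in scope): a scalar diffusion time per batch element is scaled, mapped to sin/cos features, and refined by the small MLP.

import torch

t_embed = TimestepEmbedding(dim=32, freq_embed_dim=16)
t = torch.rand(4)                  # one (normalized) timestep per batch element
print(t_embed(t).shape)            # torch.Size([4, 32])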
+ class DiTBlock(nn.Module):
3477
+ def __init__(self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0):
3478
+ super().__init__()
3479
+ self.attn_norm = AdaLayerNormZero(config.hidden_size)
3480
+
3481
+ self.attn = DiTAttention(config)
3482
+ self.look_ahead_block = look_ahead_block
3483
+ self.look_backward_block = look_backward_block
3484
+ self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6)
3485
+ self.ff = FeedForward(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout)
3486
+
3487
+ def forward(self, x, t, rope=None, block_diff=None): # x: noised input, t: time embedding
3488
+ # pre-norm & modulation for attention input
3489
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
3490
+
3491
+ # attention
3492
+ attn_output = self.attn(
3493
+ x=norm,
3494
+ rope=rope,
3495
+ mask=(block_diff >= -float(self.look_backward_block)) & (block_diff <= float(self.look_ahead_block)),
3496
+ )
3497
+
3498
+ # process attention output for input x
3499
+ x = x + gate_msa.unsqueeze(1) * attn_output
3500
+
3501
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
3502
+ ff_output = self.ff(norm)
3503
+ x = x + gate_mlp.unsqueeze(1) * ff_output
3504
+
3505
+ return x
3506
+
3507
+
3508
+ class SnakeBeta(nn.Module):
3509
+
3510
+ def __init__(self, in_features, alpha=1.0):
3511
+ super().__init__()
3512
+ self.in_features = in_features
3513
+
3514
+ # initialize alpha
3515
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
3516
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
3517
+
3518
+ self.no_div_by_zero = 0.000000001
3519
+
3520
+ def forward(self, x):
3521
+ """
3522
+ Forward pass of the function.
3523
+ Applies the function to the input elementwise.
3524
+ SnakeBeta := x + (1/beta) * sin^2(x * alpha)
3525
+ """
3526
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
3527
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
3528
+ alpha = torch.exp(alpha)
3529
+ beta = torch.exp(beta)
3530
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
3531
+
3532
+ return x
3533
+
3534
+
3535
+ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
3536
+ even = kernel_size % 2 == 0
3537
+ half_size = kernel_size // 2
3538
+
3539
+ # For kaiser window
3540
+ delta_f = 4 * half_width
3541
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
3542
+ if A > 50.0:
3543
+ beta = 0.1102 * (A - 8.7)
3544
+ elif A >= 21.0:
3545
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
3546
+ else:
3547
+ beta = 0.0
3548
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32)
3549
+
3550
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
3551
+ if even:
3552
+ time = torch.arange(-half_size, half_size) + 0.5
3553
+ else:
3554
+ time = torch.arange(kernel_size) - half_size
3555
+ if cutoff == 0:
3556
+ filter_ = torch.zeros_like(time)
3557
+ else:
3558
+ filter_ = 2 * cutoff * window * torch.sinc(2 * cutoff * time)
3559
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
3560
+ # of the constant component in the input signal.
3561
+ filter_ /= filter_.sum()
3562
+ filter = filter_.view(1, 1, kernel_size)
3563
+
3564
+ return filter
3565
+
3566
+
3567
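Editorial note: the kaiser-windowed sinc kernel above is what the `UpSample1d`/`DownSample1d` modules below use for anti-aliased resampling. A short sketch applying the returned [1, 1, K] kernel as a depthwise low-pass filter (the parameter values mirror the ratio=2 case but are otherwise illustrative):

import torch
import torch.nn.functional as F

kernel_size = 12
filt = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=kernel_size)  # [1, 1, 12]

x = torch.randn(2, 4, 64)                       # [batch, channels, time]
x_pad = F.pad(x, (kernel_size // 2 - 1, kernel_size // 2), mode="replicate")
y = F.conv1d(x_pad, filt.expand(4, -1, -1), groups=4)   # depthwise low-pass, same length
print(y.shape)                                  # torch.Size([2, 4, 64])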
+ class UpSample1d(nn.Module):
3568
+ def __init__(self, ratio=2, kernel_size=None):
3569
+ super().__init__()
3570
+ self.ratio = ratio
3571
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
3572
+ self.stride = ratio
3573
+ self.pad = self.kernel_size // ratio - 1
3574
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
3575
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
3576
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
3577
+ self.register_buffer("filter", filter, persistent=False)
3578
+
3579
+ # x: [B, C, T]
3580
+ def forward(self, x):
3581
+ _, C, _ = x.shape
3582
+
3583
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
3584
+ x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
3585
+ x = x[..., self.pad_left : -self.pad_right]
3586
+
3587
+ return x
3588
+
3589
+
3590
+ class DownSample1d(nn.Module):
3591
+ def __init__(self, ratio=2, kernel_size=None):
3592
+ super().__init__()
3593
+ cutoff = 0.5 / ratio
3594
+ half_width = 0.6 / ratio
3595
+ if cutoff < -0.0:
3596
+ raise ValueError("Minimum cutoff must be larger than zero.")
3597
+ if cutoff > 0.5:
3598
+ raise ValueError("A cutoff above 0.5 does not make sense.")
3599
+ self.kernel_size = kernel_size
3600
+ self.even = kernel_size % 2 == 0
3601
+ self.pad_left = kernel_size // 2 - int(self.even)
3602
+ self.pad_right = kernel_size // 2
3603
+ self.stride = ratio
3604
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
3605
+ self.register_buffer("filter", filter, persistent=False)
3606
+
3607
+ def forward(self, x):
3608
+ _, C, _ = x.shape
3609
+
3610
+ x = F.pad(x, (self.pad_left, self.pad_right), mode="replicate")
3611
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
3612
+
3613
+ return out
3614
+
3615
+
3616
+ class TorchActivation1d(nn.Module):
3617
+ def __init__(
3618
+ self,
3619
+ activation,
3620
+ up_ratio: int = 2,
3621
+ down_ratio: int = 2,
3622
+ up_kernel_size: int = 12,
3623
+ down_kernel_size: int = 12,
3624
+ ):
3625
+ super().__init__()
3626
+ self.up_ratio = up_ratio
3627
+ self.down_ratio = down_ratio
3628
+ self.act = activation
3629
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
3630
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
3631
+
3632
+ # x: [B,C,T]
3633
+ def forward(self, x):
3634
+ x = self.upsample(x)
3635
+ x = self.act(x)
3636
+ x = self.downsample(x)
3637
+
3638
+ return x
3639
+
3640
+
3641
+ class AMPBlock(torch.nn.Module):
3642
+ def __init__(
3643
+ self,
3644
+ channels,
3645
+ kernel_size=3,
3646
+ dilation=(1, 3, 5),
3647
+ ):
3648
+ super().__init__()
3649
+
3650
+ self.convs1 = nn.ModuleList(
3651
+ [
3652
+ nn.Conv1d(
3653
+ channels,
3654
+ channels,
3655
+ kernel_size,
3656
+ 1,
3657
+ dilation=dilation[0],
3658
+ padding=self._get_padding(kernel_size, dilation[0]),
3659
+ ),
3660
+ nn.Conv1d(
3661
+ channels,
3662
+ channels,
3663
+ kernel_size,
3664
+ 1,
3665
+ dilation=dilation[1],
3666
+ padding=self._get_padding(kernel_size, dilation[1]),
3667
+ ),
3668
+ nn.Conv1d(
3669
+ channels,
3670
+ channels,
3671
+ kernel_size,
3672
+ 1,
3673
+ dilation=dilation[2],
3674
+ padding=self._get_padding(kernel_size, dilation[2]),
3675
+ ),
3676
+ ]
3677
+ )
3678
+
3679
+ self.convs2 = nn.ModuleList(
3680
+ [
3681
+ nn.Conv1d(
3682
+ channels,
3683
+ channels,
3684
+ kernel_size,
3685
+ 1,
3686
+ dilation=1,
3687
+ padding=self._get_padding(kernel_size, 1),
3688
+ ),
3689
+ nn.Conv1d(
3690
+ channels,
3691
+ channels,
3692
+ kernel_size,
3693
+ 1,
3694
+ dilation=1,
3695
+ padding=self._get_padding(kernel_size, 1),
3696
+ ),
3697
+ nn.Conv1d(
3698
+ channels,
3699
+ channels,
3700
+ kernel_size,
3701
+ 1,
3702
+ dilation=1,
3703
+ padding=self._get_padding(kernel_size, 1),
3704
+ ),
3705
+ ]
3706
+ )
3707
+
3708
+ self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
3709
+
3710
+ self.activations = nn.ModuleList(
3711
+ [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)]
3712
+ )
3713
+
3714
+ def _get_padding(self, kernel_size, dilation=1):
3715
+ return int((kernel_size * dilation - dilation) / 2)
3716
+
3717
+ def forward(self, x):
3718
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
3719
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
3720
+ xt = a1(x)
3721
+ xt = c1(xt)
3722
+ xt = a2(xt)
3723
+ xt = c2(xt)
3724
+ x = xt + x
3725
+
3726
+ return x
3727
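+
+ # Editorial sketch (hedged): AMPBlock is a residual stack of three branches, each
+ # applying anti-aliased activation -> dilated conv -> anti-aliased activation ->
+ # conv(dilation=1) and adding the result back to the input. "Same" padding comes
+ # from _get_padding, e.g. kernel_size=3 with dilation=5 gives (3 * 5 - 5) // 2 = 5.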
+
3728
+
3729
+ class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel):
3730
+ config_class = Qwen2_5OmniBigVGANConfig
3731
+
3732
+ def __init__(self, config: Qwen2_5OmniBigVGANConfig):
3733
+ super().__init__(config)
3734
+
3735
+ self.num_kernels = len(config.resblock_kernel_sizes)
3736
+ self.num_upsamples = len(config.upsample_rates)
3737
+
3738
+ # pre conv
3739
+ self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3)
3740
+
3741
+ # transposed conv-based upsamplers. does not apply anti-aliasing
3742
+ self.ups = nn.ModuleList()
3743
+ for i, (u, k) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
3744
+ self.ups.append(
3745
+ nn.ModuleList(
3746
+ [
3747
+ ConvTranspose1d(
3748
+ config.upsample_initial_channel // (2**i),
3749
+ config.upsample_initial_channel // (2 ** (i + 1)),
3750
+ k,
3751
+ u,
3752
+ padding=(k - u) // 2,
3753
+ )
3754
+ ]
3755
+ )
3756
+ )
3757
+
3758
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
3759
+ self.resblocks = nn.ModuleList()
3760
+ for i in range(len(self.ups)):
3761
+ ch = config.upsample_initial_channel // (2 ** (i + 1))
3762
+ for j, (k, d) in enumerate(zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes)):
3763
+ self.resblocks.append(AMPBlock(ch, k, d))
3764
+
3765
+ # post conv
3766
+ self.activation_post = TorchActivation1d(activation=SnakeBeta(ch))
3767
+
3768
+ self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
3769
+
3770
+ def _normalize(self, S, max_abs_value, min_db):
3771
+ return torch.clamp(
3772
+ (2 * max_abs_value) * ((S - min_db) / (-min_db)) - max_abs_value, -max_abs_value, max_abs_value
3773
+ )
3774
+
3775
+ def _amp_to_db(self, x, min_level_db):
3776
+ min_level = np.exp(min_level_db / 20 * np.log(10))
3777
+ min_level = torch.ones_like(x) * min_level
3778
+ return 20 * torch.log10(torch.maximum(min_level, x))
3779
+
3780
+ def apm_to_db(self, apm_mel):
3781
+ mel_spec = torch.exp(apm_mel)
3782
+
3783
+ mel_spec = self._amp_to_db(mel_spec, -115) - 20
3784
+ mel_spec = self._normalize(mel_spec, 1, -115)
3785
+
3786
+ return mel_spec
3787
+
3788
+ def forward(self, apm_mel):
3789
+ mel_spec = self.apm_to_db(apm_mel)
3790
+ # pre conv
3791
+ hidden = self.conv_pre(mel_spec)
3792
+
3793
+ for i in range(self.num_upsamples):
3794
+ # upsampling
3795
+ for i_up in range(len(self.ups[i])):
3796
+ ups_i = cast(nn.Sequential, self.ups[i])
3797
+ hidden = ups_i[i_up](hidden)
3798
+ # AMP blocks
3799
+ xs = None
3800
+ for j in range(self.num_kernels):
3801
+ if xs is None:
3802
+ xs = self.resblocks[i * self.num_kernels + j](hidden)
3803
+ else:
3804
+ xs += self.resblocks[i * self.num_kernels + j](hidden)
3805
+ assert xs is not None
3806
+ hidden = xs / self.num_kernels
3807
+
3808
+ # post conv
3809
+ hidden = self.activation_post(hidden)
3810
+ hidden = self.conv_post(hidden)
3811
+ audio = torch.clamp(hidden, min=-1.0, max=1.0) # bound the output to [-1, 1]
3812
+
3813
+ return audio.squeeze().cpu()
3814
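+
+ # Editorial sketch (hedged): the BigVGAN vocoder above maps a log-amplitude mel
+ # spectrogram of shape [B, mel_dim, T] to a waveform clamped to [-1, 1]; every
+ # upsampling stage is a transposed conv whose output is averaged over num_kernels
+ # AMP residual blocks. Illustrative (hypothetical) call on a configured `vocoder`:
+ #     audio = vocoder(mel)   # 1-D CPU tensor of samples for batch size 1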
+
3815
+
3816
+ class ODESolverRK4:
3817
+ def __init__(self, func, y0):
3818
+ self.func = func
3819
+ self.y0 = y0
3820
+
3821
+ self._one_third = 1 / 3
3822
+ self._two_thirds = 2 / 3
3823
+
3824
+ def _rk4_alt_step_func(self, func, t0, dt, t1, y0, f0=None):
3825
+ k1 = f0
3826
+ if k1 is None:
3827
+ k1 = func(t0, y0)
3828
+ k2 = func(t0 + dt * self._one_third, y0 + dt * k1 * self._one_third)
3829
+ k3 = func(t0 + dt * self._two_thirds, y0 + dt * (k2 - k1 * self._one_third))
3830
+ k4 = func(t1, y0 + dt * (k1 - k2 + k3))
3831
+ return (k1 + 3 * (k2 + k3) + k4) * dt * 0.125
3832
+
3833
+ def _step_func(self, func, t0, dt, t1, y0):
3834
+ f0 = func(t0, y0)
3835
+ return self._rk4_alt_step_func(func, t0, dt, t1, y0, f0=f0), f0
3836
+
3837
+ def _linear_interp(self, t0, t1, y0, y1, t):
3838
+ if t == t0:
3839
+ return y0
3840
+ if t == t1:
3841
+ return y1
3842
+ slope = (t - t0) / (t1 - t0)
3843
+ return y0 + slope * (y1 - y0)
3844
+
3845
+ def integrate(self, t):
3846
+ solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
3847
+ solution[0] = self.y0
3848
+
3849
+ j = 1
3850
+ y0 = self.y0
3851
+ for t0, t1 in zip(t[:-1], t[1:]):
3852
+ dt = t1 - t0
3853
+ dy, f0 = self._step_func(self.func, t0, dt, t1, y0)
3854
+ y1 = y0 + dy
3855
+
3856
+ while j < len(t) and t1 >= t[j]:
3857
+ solution[j] = self._linear_interp(t0, t1, y0, y1, t[j])
3858
+ j += 1
3859
+ y0 = y1
3860
+
3861
+ return solution
3862
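+
+ # Editorial sketch (hedged): ODESolverRK4 is a fixed-step Runge-Kutta integrator in
+ # the 3/8-rule variant (k2 at t0 + dt/3, k3 at t0 + 2*dt/3, combined as
+ # (k1 + 3*k2 + 3*k3 + k4) * dt / 8), with linear interpolation onto the requested
+ # time grid. Illustrative (hypothetical) usage on dy/dt = -y:
+ #     solver = ODESolverRK4(func=lambda t, y: -y, y0=torch.ones(1))
+ #     ys = solver.integrate(torch.linspace(0.0, 1.0, 11))   # ys[-1] ~ exp(-1)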
+
3863
+
3864
+ class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel):
3865
+ config_class = Qwen2_5OmniDiTConfig
3866
+ _no_split_modules = ["DiTBlock"]
3867
+
3868
+ def __init__(self, config: Qwen2_5OmniDiTConfig):
3869
+ super().__init__(config)
3870
+ self.mel_dim = config.mel_dim
3871
+ self.repeats = config.repeats
3872
+ self.time_embed = TimestepEmbedding(config.hidden_size)
3873
+
3874
+ self.text_embed = CodecEmbedding(config.num_embeds, config.emb_dim, config.repeats)
3875
+ self.input_embed = InputEmbedding(config)
3876
+
3877
+ self.rotary_embed = RotaryEmbedding(config.head_dim)
3878
+ # self.rotary_embed = Qwen2_5OmniDiTRotaryEmbedding(config)
3879
+
3880
+ self.hidden_size = config.hidden_size
3881
+ self.layers = config.num_hidden_layers
3882
+ self.block_size = config.block_size
3883
+ self.num_attention_heads = config.num_attention_heads
3884
+
3885
+ self.transformer_blocks = nn.ModuleList()
3886
+ for i in range(config.num_hidden_layers):
3887
+ self.transformer_blocks.append(
3888
+ DiTBlock(
3889
+ config,
3890
+ look_ahead_block=1 if i in config.look_ahead_layers else 0,
3891
+ look_backward_block=1 if i in config.look_backward_layers else 0,
3892
+ )
3893
+ )
3894
+
3895
+ self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation
3896
+ self.proj_out = nn.Linear(config.hidden_size, config.mel_dim)
3897
+
3898
+ def _create_block_diff(self, x):
3899
+ batch, seq_len = x.shape[0], x.shape[1]
3900
+ block_indices = torch.arange(seq_len, device=x.device) // self.block_size # [seq_length]
3901
+
3902
+ block_i = block_indices.unsqueeze(1) # [seq_length, 1]
3903
+ block_j = block_indices.unsqueeze(0) # [1, seq_length]
3904
+
3905
+ block_diff = block_j - block_i # (n, n)
3906
+
3907
+ return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len)
3908
+
3909
+ def forward(
3910
+ self,
3911
+ x, # noised input audio
3912
+ cond, # masked cond audio
3913
+ spk, # spk embedding
3914
+ code, # code
3915
+ time, # time step # noqa: F821 F722
3916
+ drop_audio_cond=False, # cfg for cond audio
3917
+ drop_code=False, # cfg for code
3918
+ cfg=True,
3919
+ ):
3920
+ batch = x.shape[0]
3921
+ if time.ndim == 0:
3922
+ time = time.repeat(batch)
3923
+
3924
+ # t: conditioning time, c: context (code + masked cond audio), x: noised input audio
3925
+ t = self.time_embed(time)
3926
+ code_embed = self.text_embed(code, drop_code=False if cfg else drop_code)
3927
+ code_embed_uncond = self.text_embed(code, drop_code=True) if cfg else None
3928
+ hidden = self.input_embed(
3929
+ x,
3930
+ spk,
3931
+ cond,
3932
+ code_embed,
3933
+ drop_audio_cond=drop_audio_cond,
3934
+ code_embed_uncond=code_embed_uncond,
3935
+ cfg=cfg,
3936
+ )
3937
+
3938
+ # rope = self.rotary_embed(x, torch.arange(seq_len, device=x.device).repeat(batch, 1))
3939
+ rope = self.rotary_embed(hidden)
3940
+
3941
+ block_diff = self._create_block_diff(hidden)
3942
+
3943
+ for block in self.transformer_blocks:
3944
+ hidden = block(hidden, t, rope=rope, block_diff=block_diff)
3945
+
3946
+ hidden = self.norm_out(hidden, t)
3947
+ output = self.proj_out(hidden)
3948
+
3949
+ return output
3950
+
3951
+ @torch.no_grad()
3952
+ def sample(
3953
+ self,
3954
+ cond,
3955
+ ref_mel,
3956
+ code,
3957
+ steps=10,
3958
+ cfg_strength=0.5,
3959
+ sway_sampling_coef=-1.0,
3960
+ ):
3961
+ y_all = torch.randn([1, 30000, self.mel_dim], dtype=ref_mel.dtype)
3962
+ max_duration = code.shape[1] * self.repeats
3963
+ y0 = y_all[:, :max_duration].to(code.device)
3964
+ batch = ref_mel.shape[0]
3965
+ cond = cond.unsqueeze(1).repeat(1, max_duration, 1)
3966
+ assert batch == 1, "only batch size = 1 is currently supported"
3967
+
3968
+ def fn(t, x):
3969
+ if cfg_strength < 1e-5:
3970
+ pred = self(x=x, spk=cond, cond=ref_mel, code=code, time=t, drop_audio_cond=False, drop_code=False)
3971
+ return pred
3972
+
3973
+ output = self(x=x, code=code, spk=cond, cond=ref_mel, time=t, cfg=True)
3974
+ pred, null_pred = torch.chunk(output, 2, dim=0)
3975
+
3976
+ return pred + (pred - null_pred) * cfg_strength
3977
+
3978
+ t_start = 0
3979
+ t = torch.linspace(t_start, 1, steps, device=code.device, dtype=cond.dtype)
3980
+ if sway_sampling_coef is not None:
3981
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
3982
+
3983
+ solver = ODESolverRK4(func=fn, y0=y0)
3984
+ trajectory = solver.integrate(t)
3985
+
3986
+ generated = trajectory[-1]
3987
+ generated_mel_spec = generated.permute(0, 2, 1)
3988
+ return generated_mel_spec
3989
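+
+ # Editorial sketch (hedged): sample() runs flow matching from Gaussian noise y0 to a
+ # mel spectrogram. Classifier-free guidance combines the conditional and
+ # unconditional predictions as pred + (pred - null_pred) * cfg_strength, and the
+ # "sway sampling" warp t + coef * (cos(pi/2 * t) - 1 + t) with a negative coef
+ # concentrates integration steps near t = 0. Illustrative call on a configured `dit`:
+ #     mel = dit.sample(cond, ref_mel, code, steps=10)   # [1, mel_dim, T]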
+
3990
+
3991
+ @add_start_docstrings(
3992
+ (
3993
+ "The full Qwen2.5Omni Token2Wav model. Consists a DiT model take speech"
3994
+ " tokens as input and predict mel spectrogram and a BigVGAN vocoder take"
3995
+ " mel spectrogram as input and predict waveform."
3996
+ ),
3997
+ QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniToken2WavConfig"),
3998
+ )
3999
+ class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel):
4000
+ config_class = Qwen2_5OmniToken2WavConfig
4001
+ base_model_prefix = "model"
4002
+ _no_split_modules = ["Qwen2_5OmniToken2WavDiTModel", "Qwen2_5OmniToken2WavBigVGANModel"]
4003
+
4004
+ def __init__(self, config: Qwen2_5OmniToken2WavConfig):
4005
+ super().__init__(config)
4006
+ attn_impl = config._attn_implementation
4007
+ if config._attn_implementation == "flash_attention_2":
4008
+ logger.warning_once(
4009
+ "Qwen2_5OmniToken2WavModel must inference with fp32, but "
4010
+ "flash_attention_2 only supports fp16 and bf16, "
4011
+ "attention implementation of Qwen2_5OmniToken2WavModel will fallback to sdpa."
4012
+ )
4013
+ attn_impl = "sdpa"
4014
+ elif config._attn_implementation == "eager":
4015
+ logger.warning_once(
4016
+ "Qwen2_5OmniToken2WavModel does not support eager attention implementation, " "fall back to sdpa"
4017
+ )
4018
+ attn_impl = "sdpa"
4019
+ self.code2wav_dit_model = Qwen2_5OmniToken2WavDiTModel._from_config(
4020
+ config.dit_config, attn_implementation=attn_impl
4021
+ )
4022
+ self.code2wav_bigvgan_model = Qwen2_5OmniToken2WavBigVGANModel._from_config(
4023
+ config.bigvgan_config, attn_implementation=attn_impl
4024
+ )
4025
+
4026
+ def forward(
4027
+ self,
4028
+ code,
4029
+ cond,
4030
+ ref_mel,
4031
+ steps=10,
4032
+ cfg_strength=0.5,
4033
+ sway_sampling_coef=-1.0,
4034
+ **kwargs,
4035
+ ):
4036
+ generated_mel = self.code2wav_dit_model.sample(
4037
+ cond,
4038
+ ref_mel,
4039
+ code,
4040
+ steps=steps,
4041
+ cfg_strength=cfg_strength,
4042
+ sway_sampling_coef=sway_sampling_coef,
4043
+ )
4044
+ waveform = self.code2wav_bigvgan_model(generated_mel)
4045
+ return waveform
4046
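+
+ # Editorial sketch (hedged): Token2Wav chains the two modules above: speech codes ->
+ # DiT flow-matching sampler -> mel spectrogram -> BigVGAN vocoder -> waveform, all
+ # run in fp32. Illustrative (hypothetical) call on a configured `token2wav`:
+ #     wav = token2wav(code, cond=spk_embedding, ref_mel=reference_mel, steps=10)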
+
4047
+
4048
+ @add_start_docstrings(
4049
+ """""",
4050
+ QWEN2_5OMNI_START_DOCSTRING.format(config_class=Qwen2_5OmniConfig),
4051
+ )
4052
+ class Qwen2_5OmniModel(Qwen2_5OmniPreTrainedModel):
4053
+ config_class = Qwen2_5OmniConfig
4054
+ _no_split_modules = [
4055
+ "Qwen2_5OmniTalkerForConditionalGeneration",
4056
+ "Qwen2_5OmniToken2WavModel",
4057
+ ]
4058
+
4059
+ def __init__(self, config):
4060
+ super().__init__(config)
4061
+
4062
+ self.thinker = Qwen2_5OmniThinkerForConditionalGeneration(config.thinker_config)
4063
+
4064
+ self.has_talker = config.enable_audio_output
4065
+ self.speaker_map = {}
4066
+ if config.enable_audio_output:
4067
+ self.enable_talker()
4068
+
4069
+ def enable_talker(self):
4070
+ self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
4071
+ self.token2wav = Qwen2_5OmniToken2WavModel(self.config.token2wav_config)
4072
+ self.token2wav.float()
4073
+ self.has_talker = True
4074
+
4075
+ def load_speakers(self, path):
4076
+ for key, value in torch.load(path).items():
4077
+ self.speaker_map[key] = value
4078
+ logger.info("Speakers {} loaded".format(list(self.speaker_map.keys())))
4079
+
4080
+ def disable_talker(self):
4081
+ if hasattr(self, "talker"):
4082
+ del self.talker
4083
+ if hasattr(self, "token2wav"):
4084
+ del self.token2wav
4085
+ self.has_talker = False
4086
+
4087
+ @classmethod
4088
+ def can_generate(cls) -> bool:
4089
+ return True
4090
+
4091
+ @classmethod
4092
+ def from_pretrained(
4093
+ cls,
4094
+ pretrained_model_name_or_path,
4095
+ *model_args,
4096
+ config=None,
4097
+ cache_dir=None,
4098
+ ignore_mismatched_sizes=False,
4099
+ force_download=False,
4100
+ local_files_only=False,
4101
+ token=None,
4102
+ revision="main",
4103
+ use_safetensors=None,
4104
+ weights_only=True,
4105
+ **kwargs,
4106
+ ):
4107
+ model = super().from_pretrained(
4108
+ pretrained_model_name_or_path,
4109
+ *model_args,
4110
+ config=config,
4111
+ cache_dir=cache_dir,
4112
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
4113
+ force_download=force_download,
4114
+ local_files_only=local_files_only,
4115
+ token=token,
4116
+ revision=revision,
4117
+ use_safetensors=use_safetensors,
4118
+ weights_only=weights_only,
4119
+ **kwargs,
4120
+ )
4121
+ spk_path = cached_file(
4122
+ pretrained_model_name_or_path,
4123
+ "spk_dict.pt",
4124
+ subfolder=kwargs.pop("subfolder", None),
4125
+ cache_dir=kwargs.pop("cache_dir", None),
4126
+ force_download=kwargs.pop("force_download", False),
4127
+ proxies=kwargs.pop("proxies", None),
4128
+ resume_download=kwargs.pop("resume_download", None),
4129
+ local_files_only=kwargs.pop("local_files_only", False),
4130
+ token=kwargs.pop("use_auth_token", None),
4131
+ revision=kwargs.pop("revision", None),
4132
+ )
4133
+ if spk_path is None:
4134
+ raise ValueError(f"""{pretrained_model_name_or_path}/{spk_path} not exists""")
4135
+ model.load_speakers(spk_path)
4136
+
4137
+ return model
4138
+
4139
+ @torch.no_grad()
4140
+ def generate(
4141
+ self,
4142
+ input_ids: Optional[torch.Tensor] = None,
4143
+ spk: str = "Chelsie",
4144
+ use_audio_in_video: bool = False,
4145
+ return_audio: Optional[bool] = None,
4146
+ thinker_max_new_tokens: int = 1024,
4147
+ talker_max_new_tokens: int = 4096,
4148
+ talker_do_sample: bool = True,
4149
+ talker_top_k: int = 40,
4150
+ talker_top_p: float = 0.8,
4151
+ talker_temperature: float = 0.9,
4152
+ talker_eos_token_id: list[int] = [8292, 8294],
4153
+ talker_repetition_penalty: float = 1.05,
4154
+ **kwargs,
4155
+ ):
4156
+ if spk not in self.speaker_map:
4157
+ raise ValueError(f"{spk} is not availible, availible speakers: {self.speaker_map.keys()}")
4158
+ if return_audio and not self.has_talker:
4159
+ raise ValueError(
4160
+ "Cannot use talker when talker module not initalized. Use `enable_talker` "
4161
+ "method or set enable_talker in config to enable talker."
4162
+ )
4163
+ if return_audio is None:
4164
+ return_audio = self.has_talker
4165
+ assert input_ids is not None
4166
+ if input_ids.shape[0] != 1 and return_audio:
4167
+ raise NotImplementedError("Qwen2.5-Omni currently does not support batched inference with audio output")
4168
+ shared_kwargs = {"use_audio_in_video": use_audio_in_video}
4169
+ thinker_kwargs = {
4170
+ "max_new_tokens": thinker_max_new_tokens,
4171
+ }
4172
+ talker_kwargs: dict[str, Union[torch.Tensor, Any]] = {
4173
+ "max_new_tokens": talker_max_new_tokens,
4174
+ "do_sample": talker_do_sample,
4175
+ "top_k": talker_top_k,
4176
+ "top_p": talker_top_p,
4177
+ "temperature": talker_temperature,
4178
+ "eos_token_id": talker_eos_token_id,
4179
+ "repetition_penalty": talker_repetition_penalty,
4180
+ }
4181
+ token2wav_kwargs = {}
4182
+
4183
+ for key, value in kwargs.items():
4184
+ if key.startswith("thinker_"):
4185
+ thinker_kwargs[key[len("thinker_") :]] = value
4186
+ elif key.startswith("talker_"):
4187
+ talker_kwargs[key[len("talker_") :]] = value
4188
+ elif key.startswith("token2wav_"):
4189
+ token2wav_kwargs[key[len("token2wav_") :]] = value
4190
+ # Process special input values
4191
+ elif key == "feature_attention_mask":
4192
+ thinker_kwargs[key] = value
4193
+ talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1)
4194
+ elif key == "input_features" or key == "attention_mask":
4195
+ thinker_kwargs[key] = value
4196
+ # Put other key to shared kwargs
4197
+ else:
4198
+ shared_kwargs[key] = value
4199
+ # Merge kwargs
4200
+ for key, value in shared_kwargs.items():
4201
+ if key not in thinker_kwargs:
4202
+ thinker_kwargs[key] = value
4203
+ if key not in talker_kwargs:
4204
+ talker_kwargs[key] = value
4205
+ if key not in token2wav_kwargs:
4206
+ token2wav_kwargs[key] = value
4207
+ speaker_params = self.speaker_map[spk]
4208
+
4209
+ # 1. Generate from thinker module
4210
+ thinker_result = self.thinker.generate(
4211
+ input_ids=input_ids,
4212
+ return_dict_in_generate=True,
4213
+ output_hidden_states=True,
4214
+ **thinker_kwargs,
4215
+ )
4216
+ if not (return_audio and self.has_talker):
4217
+ return thinker_result.sequences
4218
+
4219
+ # 2. Generate speech tokens from talker module
4220
+ thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device)
4221
+ thinker_token_embeds = [x[0].to(self.talker.device) for x in thinker_result.hidden_states]
4222
+ thinker_hidden_states = [x[1][-1].to(self.talker.device) for x in thinker_result.hidden_states]
4223
+
4224
+ talker_text_bos_token = speaker_params["bos_token"]
4225
+ talker_input_text_ids = torch.cat(
4226
+ [
4227
+ input_ids.to(self.talker.device),
4228
+ torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.talker.device),
4229
+ thinker_generate_ids[:, :1],
4230
+ ],
4231
+ dim=-1,
4232
+ )
4233
+
4234
+ talker_input_ids = torch.cat(
4235
+ [
4236
+ torch.full_like(input_ids, fill_value=self.talker.codec_mask_token, device=self.talker.device),
4237
+ torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=self.talker.device),
4238
+ torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=self.talker.device),
4239
+ ],
4240
+ dim=1,
4241
+ )
4242
+
4243
+ thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
4244
+ talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0]
4245
+ talker_inputs_embeds = torch.cat(
4246
+ [
4247
+ talker_inputs_embeds,
4248
+ self.thinker.get_input_embeddings()(
4249
+ torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device)
4250
+ ).to(self.talker.device),
4251
+ thinker_reply_part[:, :1, :],
4252
+ ],
4253
+ dim=1,
4254
+ )
4255
+
4256
+ thinker_reply_part = torch.cat(
4257
+ [
4258
+ thinker_reply_part[:, 1:, :],
4259
+ self.thinker.get_input_embeddings()(
4260
+ torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device)
4261
+ ).to(self.talker.device),
4262
+ self.thinker.get_input_embeddings()(
4263
+ torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=self.thinker.device)
4264
+ ).to(self.talker.device),
4265
+ ],
4266
+ dim=1,
4267
+ )
4268
+
4269
+ talker_attention_mask = torch.cat(
4270
+ [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1
4271
+ ).to(self.talker.device)
4272
+
4273
+ talker_result = self.talker.generate(
4274
+ input_ids=talker_input_ids,
4275
+ input_text_ids=talker_input_text_ids,
4276
+ thinker_reply_part=thinker_reply_part,
4277
+ inputs_embeds=talker_inputs_embeds,
4278
+ attention_mask=talker_attention_mask,
4279
+ suppress_tokens=[self.talker.codec_bos_token],
4280
+ **{k: (v.to(self.talker.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()},
4281
+ )
4282
+ talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1]
4283
+
4284
+ # 3. Generate wavs from code
4285
+ if self.token2wav.dtype != torch.float:
4286
+ self.token2wav.float()
4287
+ wav = self.token2wav(
4288
+ talker_generate_codes.to(self.token2wav.device),
4289
+ cond=speaker_params["cond"].to(self.token2wav.device).float(),
4290
+ ref_mel=speaker_params["ref_mel"].to(self.token2wav.device).float(),
4291
+ **token2wav_kwargs,
4292
+ )
4293
+
4294
+ return thinker_result.sequences, wav.float()
4295
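+
+ # Editorial sketch (hedged): generate() runs a three-stage pipeline: (1) the thinker
+ # produces text tokens and hidden states, (2) the talker consumes those hidden
+ # states plus the speaker embedding to emit speech codec tokens, (3) token2wav turns
+ # the codes into a waveform. Keyword arguments are routed by prefix ("thinker_",
+ # "talker_", "token2wav_"); the rest are shared. Illustrative (hypothetical) usage:
+ #     text_ids, wav = model.generate(input_ids=ids, attention_mask=mask, spk="Chelsie")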
+
4296
+
4297
+ __all__ = [
4298
+ "Qwen2_5OmniModel",
4299
+ "Qwen2_5OmniThinkerModel",
4300
+ "Qwen2_5OmniThinkerForConditionalGeneration",
4301
+ "Qwen2_5OmniTalkerModel",
4302
+ "Qwen2_5OmniTalkerForConditionalGeneration",
4303
+ "Qwen2_5OmniToken2WavDiTModel",
4304
+ "Qwen2_5OmniToken2WavBigVGANModel",
4305
+ "Qwen2_5OmniToken2WavModel",
4306
+ "Qwen2_5OmniPreTrainedModel",
4307
+ "Qwen2_5OmniPreTrainedModelForConditionalGeneration",
4308
+ ]