crfm-helm 0.5.7__py3-none-any.whl → 0.5.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crfm-helm might be problematic.
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/METADATA +7 -77
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/RECORD +315 -282
- helm/benchmark/adaptation/adapter_spec.py +10 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +11 -3
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +11 -8
- helm/benchmark/annotation/aci_bench_annotator.py +11 -22
- helm/benchmark/annotation/alrage_annotator.py +90 -0
- helm/benchmark/annotation/chw_care_plan_annotator.py +10 -21
- helm/benchmark/annotation/dischargeme_annotator.py +11 -22
- helm/benchmark/annotation/med_dialog_annotator.py +11 -22
- helm/benchmark/annotation/medalign_annotator.py +11 -22
- helm/benchmark/annotation/medi_qa_annotator.py +11 -22
- helm/benchmark/annotation/medication_qa_annotator.py +11 -22
- helm/benchmark/annotation/mental_health_annotator.py +11 -22
- helm/benchmark/annotation/mimic_bhc_annotator.py +11 -22
- helm/benchmark/annotation/mimic_rrs_annotator.py +11 -22
- helm/benchmark/annotation/model_as_judge.py +23 -18
- helm/benchmark/annotation/mtsamples_procedures_annotator.py +11 -22
- helm/benchmark/annotation/mtsamples_replicate_annotator.py +11 -22
- helm/benchmark/annotation/starr_patient_instructions_annotator.py +11 -22
- helm/benchmark/metrics/air_bench_metrics.py +3157 -1
- helm/benchmark/metrics/alrage_metric.py +35 -0
- helm/benchmark/metrics/basic_metrics.py +267 -2
- helm/benchmark/metrics/bbq_metrics.py +12 -0
- helm/benchmark/metrics/classification_metrics.py +19 -1
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +12 -1
- helm/benchmark/metrics/dry_run_metrics.py +30 -1
- helm/benchmark/metrics/efficiency_metrics.py +74 -0
- helm/benchmark/metrics/ehr_sql_metrics.py +57 -1
- helm/benchmark/metrics/evaluate_reference_metrics.py +311 -0
- helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +13 -1
- helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +13 -1
- helm/benchmark/metrics/ifeval_metrics.py +13 -1
- helm/benchmark/metrics/instruction_following_critique_metrics.py +41 -1
- helm/benchmark/metrics/kpi_edgar_metrics.py +21 -0
- helm/benchmark/metrics/language_modeling_metrics.py +13 -1
- helm/benchmark/metrics/live_qa_metrics.py +13 -1
- helm/benchmark/metrics/llm_jury_metrics.py +13 -1
- helm/benchmark/metrics/medcalc_bench_metrics.py +14 -1
- helm/benchmark/metrics/medec_metrics.py +25 -2
- helm/benchmark/metrics/metric.py +25 -0
- helm/benchmark/metrics/mimiciv_billing_code_metrics.py +32 -1
- helm/benchmark/metrics/omni_math_metrics.py +13 -1
- helm/benchmark/metrics/safety_metrics.py +13 -1
- helm/benchmark/metrics/seahelm_metrics.py +14 -1
- helm/benchmark/metrics/summac/model_summac.py +2 -2
- helm/benchmark/metrics/summarization_metrics.py +129 -1
- helm/benchmark/metrics/toxicity_metrics.py +31 -1
- helm/benchmark/metrics/ultra_suite_asr_classification_metrics.py +52 -0
- helm/benchmark/metrics/wildbench_metrics.py +21 -1
- helm/benchmark/presentation/run_display.py +13 -3
- helm/benchmark/presentation/run_entry.py +2 -2
- helm/benchmark/presentation/schema.py +5 -22
- helm/benchmark/presentation/summarize.py +180 -11
- helm/benchmark/presentation/taxonomy_info.py +20 -0
- helm/benchmark/run.py +1 -1
- helm/benchmark/run_expander.py +4 -0
- helm/benchmark/run_specs/arabic_run_specs.py +140 -16
- helm/benchmark/run_specs/bluex_run_specs.py +1 -1
- helm/benchmark/run_specs/classic_run_specs.py +2 -2
- helm/benchmark/run_specs/long_context_run_specs.py +2 -2
- helm/benchmark/run_specs/medhelm/__init__.py +0 -0
- helm/benchmark/run_specs/medhelm/benchmark_config.py +219 -0
- helm/benchmark/run_specs/medhelm_run_specs.py +362 -52
- helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +6 -2
- helm/benchmark/scenarios/aci_bench_scenario.py +23 -0
- helm/benchmark/scenarios/air_bench_scenario.py +21 -0
- helm/benchmark/scenarios/alrage_scenario.py +54 -0
- helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +23 -1
- helm/benchmark/scenarios/anthropic_red_team_scenario.py +12 -1
- helm/benchmark/scenarios/arabic_exams_scenario.py +114 -0
- helm/benchmark/scenarios/arabic_mmlu_scenario.py +8 -4
- helm/benchmark/scenarios/aratrust_scenario.py +19 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification_scenario.py +24 -54
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +19 -48
- helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +22 -61
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +21 -29
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +21 -60
- helm/benchmark/scenarios/babi_qa_scenario.py +15 -0
- helm/benchmark/scenarios/banking77_scenario.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +15 -0
- helm/benchmark/scenarios/best_chatgpt_prompts.yaml +473 -0
- helm/benchmark/scenarios/bird_sql_scenario.py +18 -0
- helm/benchmark/scenarios/bluex_scenario.py +6 -2
- helm/benchmark/scenarios/bold_scenario.py +15 -0
- helm/benchmark/scenarios/boolq_scenario.py +20 -0
- helm/benchmark/scenarios/chw_care_plan_scenario.py +23 -0
- helm/benchmark/scenarios/civil_comments_scenario.py +13 -0
- helm/benchmark/scenarios/clear_scenario.py +23 -0
- helm/benchmark/scenarios/cleva_scenario.py +479 -0
- helm/benchmark/scenarios/code_scenario.py +28 -0
- helm/benchmark/scenarios/commonsense_scenario.py +32 -0
- helm/benchmark/scenarios/compositional_instructions.yaml +70 -0
- helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +21 -0
- helm/benchmark/scenarios/copyright_scenario.py +35 -1
- helm/benchmark/scenarios/cti_to_mitre_scenario.py +21 -0
- helm/benchmark/scenarios/czech_bank_qa_scenario.py +18 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +22 -1
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +23 -1
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +22 -1
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +21 -1
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +13 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +13 -1
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +13 -1
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +13 -1
- helm/benchmark/scenarios/dischargeme_scenario.py +24 -0
- helm/benchmark/scenarios/disinformation_scenario.py +22 -0
- helm/benchmark/scenarios/dyck_language_scenario.py +15 -0
- helm/benchmark/scenarios/ehrshot_scenario.py +22 -0
- helm/benchmark/scenarios/enem_challenge_scenario.py +19 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +14 -0
- helm/benchmark/scenarios/entity_matching_scenario.py +14 -0
- helm/benchmark/scenarios/fin_qa_scenario.py +20 -0
- helm/benchmark/scenarios/financebench_scenario.py +21 -0
- helm/benchmark/scenarios/financial_phrasebank_scenario.py +21 -0
- helm/benchmark/scenarios/gold_commodity_news_scenario.py +21 -0
- helm/benchmark/scenarios/gpqa_scenario.py +18 -0
- helm/benchmark/scenarios/grammar_scenario.py +20 -1
- helm/benchmark/scenarios/gsm_scenario.py +21 -0
- helm/benchmark/scenarios/harm_bench_gcg_transfer_scenario.py +12 -1
- helm/benchmark/scenarios/harm_bench_scenario.py +12 -1
- helm/benchmark/scenarios/headqa_scenario.py +22 -0
- helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +13 -0
- helm/benchmark/scenarios/ice_scenario.py +21 -1
- helm/benchmark/scenarios/ifeval_scenario.py +18 -0
- helm/benchmark/scenarios/imdb_scenario.py +15 -0
- helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +21 -0
- helm/benchmark/scenarios/infinite_bench_en_sum_scenario.py +19 -0
- helm/benchmark/scenarios/koala_scenario.py +21 -1
- helm/benchmark/scenarios/kpi_edgar_scenario.py +21 -0
- helm/benchmark/scenarios/legal_contract_summarization_scenario.py +20 -0
- helm/benchmark/scenarios/legal_summarization_scenario.py +50 -0
- helm/benchmark/scenarios/legal_support_scenario.py +13 -0
- helm/benchmark/scenarios/legalbench_scenario.py +19 -0
- helm/benchmark/scenarios/lex_glue_scenario.py +11 -0
- helm/benchmark/scenarios/lextreme_scenario.py +11 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +14 -0
- helm/benchmark/scenarios/madinah_qa_scenario.py +73 -0
- helm/benchmark/scenarios/math_scenario.py +33 -0
- helm/benchmark/scenarios/mbzuai_human_translated_arabic_mmlu.py +68 -0
- helm/benchmark/scenarios/med_dialog_scenario.py +32 -1
- helm/benchmark/scenarios/med_mcqa_scenario.py +14 -0
- helm/benchmark/scenarios/med_qa_scenario.py +20 -0
- helm/benchmark/scenarios/medalign_scenario.py +23 -0
- helm/benchmark/scenarios/medbullets_scenario.py +22 -0
- helm/benchmark/scenarios/medcalc_bench_scenario.py +22 -0
- helm/benchmark/scenarios/medec_scenario.py +23 -0
- helm/benchmark/scenarios/medhallu_scenario.py +23 -0
- helm/benchmark/scenarios/medhelm/__init__.py +0 -0
- helm/benchmark/scenarios/medhelm/judges.yaml +14 -0
- helm/benchmark/scenarios/medhelm_configurable_scenario.py +101 -0
- helm/benchmark/scenarios/medi_qa_scenario.py +24 -1
- helm/benchmark/scenarios/medication_qa_scenario.py +31 -1
- helm/benchmark/scenarios/mental_health_scenario.py +23 -0
- helm/benchmark/scenarios/mimic_bhc_scenario.py +24 -0
- helm/benchmark/scenarios/mimic_rrs_scenario.py +23 -0
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +22 -0
- helm/benchmark/scenarios/mmlu_pro_scenario.py +18 -0
- helm/benchmark/scenarios/mmlu_scenario.py +21 -0
- helm/benchmark/scenarios/msmarco_scenario.py +30 -0
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +22 -0
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +22 -0
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +20 -0
- helm/benchmark/scenarios/narrativeqa_scenario.py +19 -0
- helm/benchmark/scenarios/natural_qa_scenario.py +32 -0
- helm/benchmark/scenarios/omni_math_scenario.py +18 -0
- helm/benchmark/scenarios/open_assistant_scenario.py +22 -0
- helm/benchmark/scenarios/openai_mrcr_scenario.py +15 -0
- helm/benchmark/scenarios/pubmed_qa_scenario.py +22 -0
- helm/benchmark/scenarios/quac_scenario.py +14 -0
- helm/benchmark/scenarios/race_based_med_scenario.py +23 -0
- helm/benchmark/scenarios/raft_scenario.py +15 -0
- helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +14 -1
- helm/benchmark/scenarios/ruler_qa_scenarios.py +40 -0
- helm/benchmark/scenarios/scenario.py +31 -0
- helm/benchmark/scenarios/seahelm_scenario.py +348 -0
- helm/benchmark/scenarios/self_instruct_scenario.py +29 -1
- helm/benchmark/scenarios/shc_bmt_scenario.py +22 -0
- helm/benchmark/scenarios/shc_cdi_scenario.py +20 -0
- helm/benchmark/scenarios/shc_conf_scenario.py +23 -0
- helm/benchmark/scenarios/shc_ent_scenario.py +21 -0
- helm/benchmark/scenarios/shc_gip_scenario.py +20 -0
- helm/benchmark/scenarios/shc_privacy_scenario.py +22 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +22 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +23 -0
- helm/benchmark/scenarios/shc_sequoia_scenario.py +21 -0
- helm/benchmark/scenarios/simple_safety_tests_scenario.py +12 -1
- helm/benchmark/scenarios/situation_prompts.yaml +49 -0
- helm/benchmark/scenarios/spider_scenario.py +18 -0
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +22 -0
- helm/benchmark/scenarios/summarization_scenario.py +37 -0
- helm/benchmark/scenarios/synthetic_efficiency_scenario.py +22 -1
- helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +13 -0
- helm/benchmark/scenarios/test_alrage_scenario.py +23 -0
- helm/benchmark/scenarios/test_arabic_exams_scenario.py +21 -0
- helm/benchmark/scenarios/test_aratrust_scenario.py +1 -1
- helm/benchmark/scenarios/test_bluex_scenario.py +2 -2
- helm/benchmark/scenarios/thai_exam_scenario.py +95 -0
- helm/benchmark/scenarios/the_pile_scenario.py +13 -1
- helm/benchmark/scenarios/truthful_qa_scenario.py +14 -0
- helm/benchmark/scenarios/twitter_aae_scenario.py +20 -1
- helm/benchmark/scenarios/vicuna_scenario.py +21 -1
- helm/benchmark/scenarios/wikifact_scenario.py +20 -0
- helm/benchmark/scenarios/wildbench_scenario.py +18 -0
- helm/benchmark/scenarios/wmt_14_scenario.py +19 -0
- helm/benchmark/static/schema_arabic.yaml +55 -12
- helm/benchmark/static/schema_long_context.yaml +11 -30
- helm/benchmark/static/schema_medhelm.yaml +36 -0
- helm/benchmark/static/schema_slp.yaml +219 -0
- helm/benchmark/static_build/assets/audio-table-Dn5NMMeJ.png +0 -0
- helm/benchmark/static_build/assets/index-oIeiQW2g.css +1 -0
- helm/benchmark/static_build/assets/index-qOFpOyHb.js +10 -0
- helm/benchmark/static_build/assets/react-BteFIppM.js +85 -0
- helm/benchmark/static_build/assets/recharts-DxuQtTOs.js +97 -0
- helm/benchmark/static_build/assets/tremor-DR4fE7ko.js +10 -0
- helm/benchmark/static_build/index.html +5 -6
- helm/clients/ai21_client.py +2 -0
- helm/clients/aleph_alpha_client.py +2 -0
- helm/clients/anthropic_client.py +7 -1
- helm/clients/audio_language/diva_llama_client.py +2 -0
- helm/clients/audio_language/llama_omni/arguments.py +61 -0
- helm/clients/audio_language/llama_omni/constants.py +9 -0
- helm/clients/audio_language/llama_omni/conversation.py +213 -0
- helm/clients/audio_language/llama_omni/model/__init__.py +0 -0
- helm/clients/audio_language/llama_omni/model/builder.py +88 -0
- helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py +190 -0
- helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py +118 -0
- helm/clients/audio_language/llama_omni/model/omni_speech_arch.py +249 -0
- helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py +27 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/generation.py +622 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py +104 -0
- helm/clients/audio_language/llama_omni/model/speech_projector/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py +27 -0
- helm/clients/audio_language/llama_omni/preprocess.py +295 -0
- helm/clients/audio_language/llama_omni/utils.py +202 -0
- helm/clients/audio_language/llama_omni_client.py +2 -1
- helm/clients/audio_language/qwen2_5_omni_client.py +2 -1
- helm/clients/audio_language/qwen2_audiolm_client.py +2 -1
- helm/clients/audio_language/qwen_audiolm_client.py +2 -1
- helm/clients/audio_language/qwen_omni/configuration_qwen2_5_omni.py +519 -0
- helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py +4308 -0
- helm/clients/audio_language/qwen_omni/processing_qwen2_5_omni.py +270 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/__init__.py +0 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/__init__.py +8 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/audio_process.py +56 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py +380 -0
- helm/clients/bedrock_client.py +2 -0
- helm/clients/cohere_client.py +3 -0
- helm/clients/google_client.py +2 -0
- helm/clients/http_model_client.py +2 -0
- helm/clients/huggingface_client.py +2 -1
- helm/clients/ibm_client.py +3 -1
- helm/clients/image_generation/adobe_vision_client.py +2 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +2 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +1 -1
- helm/clients/image_generation/cogview2_client.py +2 -1
- helm/clients/image_generation/dalle2_client.py +2 -0
- helm/clients/image_generation/dalle_mini_client.py +2 -1
- helm/clients/image_generation/deep_floyd_client.py +2 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +2 -1
- helm/clients/image_generation/lexica_client.py +2 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +2 -2
- helm/clients/image_generation/mindalle_client.py +2 -1
- helm/clients/image_generation/together_image_generation_client.py +2 -0
- helm/clients/megatron_client.py +2 -0
- helm/clients/mistral_client.py +2 -0
- helm/clients/moderation_api_client.py +2 -0
- helm/clients/openai_client.py +36 -20
- helm/clients/openai_responses_client.py +27 -3
- helm/clients/openrouter_client.py +31 -0
- helm/clients/palmyra_client.py +2 -1
- helm/clients/reka_client.py +2 -1
- helm/clients/stanfordhealthcare_azure_openai_client.py +2 -2
- helm/clients/stanfordhealthcare_http_model_client.py +2 -0
- helm/clients/test_openrouter_client.py +69 -0
- helm/clients/together_client.py +52 -11
- helm/clients/vertexai_client.py +12 -2
- helm/clients/vision_language/huggingface_vision2seq_client.py +2 -1
- helm/clients/vision_language/huggingface_vlm_client.py +2 -0
- helm/clients/vision_language/idefics_client.py +2 -1
- helm/clients/vision_language/open_flamingo_client.py +2 -1
- helm/clients/vision_language/paligemma_client.py +2 -1
- helm/clients/vision_language/palmyra_vision_client.py +2 -0
- helm/clients/vision_language/qwen2_vlm_client.py +2 -1
- helm/clients/vision_language/qwen_vlm_client.py +2 -1
- helm/clients/writer_client.py +2 -0
- helm/common/hierarchical_logger.py +20 -0
- helm/common/optional_dependencies.py +1 -1
- helm/common/test_general.py +4 -0
- helm/config/model_deployments.yaml +300 -1
- helm/config/model_metadata.yaml +302 -9
- helm/config/tokenizer_configs.yaml +92 -4
- helm/proxy/example_queries.py +8 -8
- helm/proxy/server.py +2 -1
- helm/proxy/static/index.css +4 -0
- helm/proxy/static/index.js +7 -1
- helm/benchmark/metrics/aci_bench_metrics.py +0 -14
- helm/benchmark/metrics/chw_care_plan_metrics.py +0 -14
- helm/benchmark/metrics/dischargeme_metrics.py +0 -14
- helm/benchmark/metrics/med_dialog_metrics.py +0 -14
- helm/benchmark/metrics/medalign_metrics.py +0 -14
- helm/benchmark/metrics/medi_qa_metrics.py +0 -14
- helm/benchmark/metrics/medication_qa_metrics.py +0 -14
- helm/benchmark/metrics/mental_health_metrics.py +0 -14
- helm/benchmark/metrics/mimic_bhc_metrics.py +0 -14
- helm/benchmark/metrics/mimic_rrs_metrics.py +0 -14
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +0 -14
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +0 -14
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +0 -14
- helm/benchmark/static_build/assets/index-b9779128.css +0 -1
- helm/benchmark/static_build/assets/index-e439d5e1.js +0 -10
- helm/benchmark/static_build/assets/react-f82877fd.js +0 -85
- helm/benchmark/static_build/assets/recharts-4037aff0.js +0 -97
- helm/benchmark/static_build/assets/tremor-38a10867.js +0 -10
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/WHEEL +0 -0
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/licenses/LICENSE +0 -0
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/top_level.txt +0 -0
- /helm/benchmark/static_build/assets/{air-overview-d2e6c49f.png → air-overview-DpBbyagA.png} +0 -0
- /helm/benchmark/static_build/assets/{crfm-logo-74391ab8.png → crfm-logo-Du4T1uWZ.png} +0 -0
- /helm/benchmark/static_build/assets/{heim-logo-3e5e3aa4.png → heim-logo-BJtQlEbV.png} +0 -0
- /helm/benchmark/static_build/assets/{helm-logo-simple-2ed5400b.png → helm-logo-simple-DzOhNN41.png} +0 -0
- /helm/benchmark/static_build/assets/{helm-safety-2907a7b6.png → helm-safety-COfndXuS.png} +0 -0
- /helm/benchmark/static_build/assets/{helmhero-28e90f4d.png → helmhero-D9TvmJsp.png} +0 -0
- /helm/benchmark/static_build/assets/{medhelm-overview-eac29843.png → medhelm-overview-CND0EIsy.png} +0 -0
- /helm/benchmark/static_build/assets/{medhelm-v1-overview-3ddfcd65.png → medhelm-v1-overview-Cu2tphBB.png} +0 -0
- /helm/benchmark/static_build/assets/{overview-74aea3d8.png → overview-BwypNWnk.png} +0 -0
- /helm/benchmark/static_build/assets/{process-flow-bd2eba96.png → process-flow-DWDJC733.png} +0 -0
- /helm/benchmark/static_build/assets/{vhelm-aspects-1437d673.png → vhelm-aspects-NiDQofvP.png} +0 -0
- /helm/benchmark/static_build/assets/{vhelm-framework-a1ca3f3f.png → vhelm-framework-NxJE4fdA.png} +0 -0
- /helm/benchmark/static_build/assets/{vhelm-model-8afb7616.png → vhelm-model-ypCL5Yvq.png} +0 -0
|
@@ -0,0 +1,4308 @@
|
|
|
1
|
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
2
|
+
# This file was automatically generated from src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py.
|
|
3
|
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
4
|
+
# the file from the modular. If any change should be done, please apply the change to the
|
|
5
|
+
# modular_qwen2_5_omni.py file directly. One of our CI enforces this.
|
|
6
|
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
7
|
+
# coding=utf-8
|
|
8
|
+
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
#
|
|
11
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
12
|
+
# you may not use this file except in compliance with the License.
|
|
13
|
+
# You may obtain a copy of the License at
|
|
14
|
+
#
|
|
15
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
16
|
+
#
|
|
17
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
18
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
19
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
20
|
+
# See the License for the specific language governing permissions and
|
|
21
|
+
# limitations under the License.
|
|
22
|
+
|
|
23
|
+
import math
|
|
24
|
+
import operator
|
|
25
|
+
from dataclasses import dataclass
|
|
26
|
+
from itertools import accumulate
|
|
27
|
+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
|
28
|
+
|
|
29
|
+
import numpy as np
|
|
30
|
+
import torch
|
|
31
|
+
import torch.nn as nn
|
|
32
|
+
import torch.nn.functional as F
|
|
33
|
+
from torch.nn import ConvTranspose1d, Parameter
|
|
34
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
35
|
+
|
|
36
|
+
from transformers.activations import ACT2FN
|
|
37
|
+
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, SlidingWindowCache, StaticCache
|
|
38
|
+
from transformers.generation import GenerationMixin
|
|
39
|
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
|
40
|
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
|
|
41
|
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
|
42
|
+
|
|
43
|
+
try:
|
|
44
|
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
45
|
+
except ModuleNotFoundError as e:
|
|
46
|
+
handle_module_not_found_error(e, ["audiolm"])
|
|
47
|
+
from transformers.utils import (
|
|
48
|
+
add_start_docstrings,
|
|
49
|
+
is_flash_attn_2_available,
|
|
50
|
+
is_flash_attn_greater_or_equal_2_10,
|
|
51
|
+
logging,
|
|
52
|
+
)
|
|
53
|
+
from transformers.utils.hub import cached_file
|
|
54
|
+
from helm.clients.audio_language.qwen_omni.configuration_qwen2_5_omni import (
|
|
55
|
+
Qwen2_5OmniAudioEncoderConfig,
|
|
56
|
+
Qwen2_5OmniBigVGANConfig,
|
|
57
|
+
Qwen2_5OmniConfig,
|
|
58
|
+
Qwen2_5OmniDiTConfig,
|
|
59
|
+
Qwen2_5OmniTalkerConfig,
|
|
60
|
+
Qwen2_5OmniTextConfig,
|
|
61
|
+
Qwen2_5OmniThinkerConfig,
|
|
62
|
+
Qwen2_5OmniToken2WavConfig,
|
|
63
|
+
Qwen2_5OmniVisionEncoderConfig,
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
if is_flash_attn_2_available():
|
|
68
|
+
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
|
|
69
|
+
from flash_attn.layers.rotary import apply_rotary_emb
|
|
70
|
+
else:
|
|
71
|
+
flash_attn_varlen_func = None
|
|
72
|
+
apply_rotary_emb = None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
if is_flash_attn_2_available():
|
|
76
|
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
|
77
|
+
else:
|
|
78
|
+
flash_attn_varlen_func = None
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
logger = logging.get_logger(__name__)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# @add_start_docstrings(
|
|
85
|
+
# "The bare Qwen2.5Omni Model outputting raw hidden-states without any specific head on top.",
|
|
86
|
+
# QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniConfig"),
|
|
87
|
+
# )
|
|
88
|
+
class Qwen2_5OmniPreTrainedModel(PreTrainedModel):
|
|
89
|
+
config_class: Any = Qwen2_5OmniConfig
|
|
90
|
+
base_model_prefix = "model"
|
|
91
|
+
supports_gradient_checkpointing = True
|
|
92
|
+
_skip_keys_device_placement = "past_key_values"
|
|
93
|
+
_supports_flash_attn_2 = True
|
|
94
|
+
_supports_sdpa = True
|
|
95
|
+
_supports_cache_class = True
|
|
96
|
+
_supports_static_cache = True
|
|
97
|
+
|
|
98
|
+
def _init_weights(self, module):
|
|
99
|
+
# important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only
|
|
100
|
+
# inference and fine-tuning - so the proper init weights code has been removed
|
|
101
|
+
std = self.config.init_std if hasattr(self.config, "init_std") else 0.02
|
|
102
|
+
|
|
103
|
+
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d)):
|
|
104
|
+
module.weight.data.normal_(mean=0.0, std=std)
|
|
105
|
+
if module.bias is not None:
|
|
106
|
+
module.bias.data.zero_()
|
|
107
|
+
elif isinstance(module, nn.Embedding):
|
|
108
|
+
module.weight.data.normal_(mean=0.0, std=std)
|
|
109
|
+
if module.padding_idx is not None:
|
|
110
|
+
module.weight.data[module.padding_idx].zero_()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel):
|
|
114
|
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
|
115
|
+
self,
|
|
116
|
+
attention_mask: torch.Tensor,
|
|
117
|
+
sequence_length: int,
|
|
118
|
+
target_length: int,
|
|
119
|
+
dtype: torch.dtype,
|
|
120
|
+
device: torch.device,
|
|
121
|
+
min_dtype: float,
|
|
122
|
+
cache_position: torch.Tensor,
|
|
123
|
+
batch_size: int,
|
|
124
|
+
):
|
|
125
|
+
if attention_mask is not None and attention_mask.dim() == 4:
|
|
126
|
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
|
127
|
+
causal_mask = attention_mask
|
|
128
|
+
else:
|
|
129
|
+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
|
130
|
+
if sequence_length != 1:
|
|
131
|
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
|
132
|
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
|
133
|
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
|
134
|
+
if attention_mask is not None:
|
|
135
|
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
|
136
|
+
mask_length = attention_mask.shape[-1]
|
|
137
|
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
|
138
|
+
padding_mask = padding_mask == 0
|
|
139
|
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
|
140
|
+
padding_mask, min_dtype
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
return causal_mask
|
|
144
|
+
|
|
145
|
+
def get_input_embeddings(self):
|
|
146
|
+
return self.model.get_input_embeddings()
|
|
147
|
+
|
|
148
|
+
def set_input_embeddings(self, value):
|
|
149
|
+
self.model.set_input_embeddings(value)
|
|
150
|
+
|
|
151
|
+
def get_llm_pos_ids_for_vision(
|
|
152
|
+
self,
|
|
153
|
+
start_idx: int,
|
|
154
|
+
vision_idx: int,
|
|
155
|
+
spatial_merge_size: int,
|
|
156
|
+
t_index: torch.Tensor,
|
|
157
|
+
grid_hs: torch.Tensor,
|
|
158
|
+
grid_ws: torch.Tensor,
|
|
159
|
+
):
|
|
160
|
+
llm_pos_ids_list = []
|
|
161
|
+
llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
|
|
162
|
+
llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
|
|
163
|
+
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten()
|
|
164
|
+
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten()
|
|
165
|
+
t_index_p = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long()
|
|
166
|
+
_llm_pos_ids = torch.stack([t_index_p, h_index, w_index])
|
|
167
|
+
llm_pos_ids_list.append(_llm_pos_ids + start_idx) # + 1 ) # 12.09 by malinhan
|
|
168
|
+
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
|
|
169
|
+
return llm_pos_ids
|
|
170
|
+
|
|
171
|
+
def get_chunked_index(self, llm_pos_ids, t_ntoken_per_chunk, st_idx):
|
|
172
|
+
def _iter():
|
|
173
|
+
i, start_idx = 0, 0 # skip bos token
|
|
174
|
+
current_chunk = 1
|
|
175
|
+
while i < llm_pos_ids.shape[1]: # skip eos token
|
|
176
|
+
if llm_pos_ids[0][i] - st_idx >= current_chunk * t_ntoken_per_chunk:
|
|
177
|
+
yield (start_idx, i)
|
|
178
|
+
start_idx = i
|
|
179
|
+
current_chunk += 1
|
|
180
|
+
i += 1
|
|
181
|
+
yield (start_idx, llm_pos_ids.shape[1])
|
|
182
|
+
|
|
183
|
+
return list(_iter())
|
|
184
|
+
|
|
185
|
+
def get_rope_index(
|
|
186
|
+
self,
|
|
187
|
+
input_ids: Optional[torch.LongTensor] = None,
|
|
188
|
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
|
189
|
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
|
190
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
191
|
+
use_audio_in_video: Optional[bool] = False,
|
|
192
|
+
audio_seqlens: Optional[torch.Tensor] = None,
|
|
193
|
+
second_per_grids: Optional[torch.Tensor] = None,
|
|
194
|
+
):
|
|
195
|
+
spatial_merge_size = self.spatial_merge_size
|
|
196
|
+
image_token_id = self.config.image_token_index
|
|
197
|
+
video_token_id = self.config.video_token_index
|
|
198
|
+
audio_token_id = self.config.audio_token_index
|
|
199
|
+
vision_start_token_id = self.config.vision_start_token_id
|
|
200
|
+
audio_start_token_id = self.config.audio_start_token_id
|
|
201
|
+
position_id_per_seconds = self.config.position_id_per_seconds
|
|
202
|
+
seconds_per_chunk = self.config.seconds_per_chunk
|
|
203
|
+
|
|
204
|
+
mrope_position_deltas = []
|
|
205
|
+
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
|
|
206
|
+
total_input_ids = input_ids
|
|
207
|
+
if attention_mask is None:
|
|
208
|
+
attention_mask = torch.ones_like(total_input_ids)
|
|
209
|
+
position_ids = torch.ones(
|
|
210
|
+
3,
|
|
211
|
+
input_ids.shape[0],
|
|
212
|
+
input_ids.shape[1],
|
|
213
|
+
dtype=input_ids.dtype,
|
|
214
|
+
device=input_ids.device,
|
|
215
|
+
)
|
|
216
|
+
image_idx, video_idx, audio_idx = 0, 0, 0
|
|
217
|
+
attention_mask = attention_mask.to(total_input_ids.device)
|
|
218
|
+
for i, input_ids_p in enumerate(total_input_ids):
|
|
219
|
+
input_ids_p = input_ids_p[attention_mask[i] == 1]
|
|
220
|
+
image_nums, video_nums, audio_nums = 0, 0, 0
|
|
221
|
+
vision_start_indices = torch.argwhere(input_ids_p == vision_start_token_id).squeeze(1)
|
|
222
|
+
vision_tokens = input_ids_p[vision_start_indices + 1]
|
|
223
|
+
audio_nums = int(torch.sum(input_ids_p == audio_start_token_id).item())
|
|
224
|
+
image_nums = (vision_tokens == image_token_id).sum()
|
|
225
|
+
video_nums = (
|
|
226
|
+
(vision_tokens == audio_start_token_id).sum()
|
|
227
|
+
if use_audio_in_video
|
|
228
|
+
else (vision_tokens == video_token_id).sum()
|
|
229
|
+
)
|
|
230
|
+
input_tokens = input_ids_p.tolist()
|
|
231
|
+
llm_pos_ids_list: list = []
|
|
232
|
+
st = 0
|
|
233
|
+
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums
|
|
234
|
+
multimodal_nums = (
|
|
235
|
+
image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums
|
|
236
|
+
)
|
|
237
|
+
for _ in range(multimodal_nums):
|
|
238
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
239
|
+
if image_token_id in input_tokens and remain_images > 0:
|
|
240
|
+
ed_image = input_tokens.index(image_token_id, st)
|
|
241
|
+
else:
|
|
242
|
+
ed_image = len(input_tokens) + 1
|
|
243
|
+
if video_token_id in input_tokens and remain_videos > 0:
|
|
244
|
+
ed_video = input_tokens.index(video_token_id, st)
|
|
245
|
+
else:
|
|
246
|
+
ed_video = len(input_tokens) + 1
|
|
247
|
+
if audio_token_id in input_tokens and remain_audios > 0:
|
|
248
|
+
ed_audio = input_tokens.index(audio_token_id, st)
|
|
249
|
+
else:
|
|
250
|
+
ed_audio = len(input_tokens) + 1
|
|
251
|
+
min_ed = min(ed_image, ed_video, ed_audio)
|
|
252
|
+
if min_ed == ed_audio and audio_seqlens is not None:
|
|
253
|
+
text_len = min_ed - st - 1
|
|
254
|
+
if text_len != 0:
|
|
255
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
256
|
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
|
257
|
+
|
|
258
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
259
|
+
bos_len = 1
|
|
260
|
+
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
|
|
261
|
+
|
|
262
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
263
|
+
audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1
|
|
264
|
+
llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
|
|
265
|
+
llm_pos_ids_list.append(llm_pos_ids)
|
|
266
|
+
|
|
267
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
268
|
+
eos_len = 1
|
|
269
|
+
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
|
|
270
|
+
|
|
271
|
+
st += text_len + bos_len + audio_len + eos_len
|
|
272
|
+
audio_idx += 1
|
|
273
|
+
remain_audios -= 1
|
|
274
|
+
|
|
275
|
+
elif min_ed == ed_image and image_grid_thw is not None:
|
|
276
|
+
text_len = min_ed - st - 1
|
|
277
|
+
if text_len != 0:
|
|
278
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
279
|
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
|
280
|
+
|
|
281
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
282
|
+
bos_len = 1
|
|
283
|
+
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
|
|
284
|
+
|
|
285
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
286
|
+
grid_t = image_grid_thw[image_idx][0]
|
|
287
|
+
grid_hs = image_grid_thw[:, 1]
|
|
288
|
+
grid_ws = image_grid_thw[:, 2]
|
|
289
|
+
t_index = (torch.arange(grid_t.item()) * 1 * position_id_per_seconds).long()
|
|
290
|
+
llm_pos_ids = self.get_llm_pos_ids_for_vision(
|
|
291
|
+
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
|
292
|
+
)
|
|
293
|
+
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
|
|
294
|
+
llm_pos_ids_list.append(llm_pos_ids)
|
|
295
|
+
|
|
296
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
297
|
+
eos_len = 1
|
|
298
|
+
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
|
|
299
|
+
|
|
300
|
+
st += text_len + bos_len + image_len + eos_len
|
|
301
|
+
image_idx += 1
|
|
302
|
+
remain_images -= 1
|
|
303
|
+
|
|
304
|
+
elif (
|
|
305
|
+
min_ed == ed_video
|
|
306
|
+
and not use_audio_in_video
|
|
307
|
+
and video_grid_thw is not None
|
|
308
|
+
and second_per_grids is not None
|
|
309
|
+
):
|
|
310
|
+
text_len = min_ed - st - 1
|
|
311
|
+
if text_len != 0:
|
|
312
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
313
|
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
|
314
|
+
|
|
315
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
316
|
+
bos_len = 1
|
|
317
|
+
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
|
|
318
|
+
|
|
319
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
320
|
+
grid_t = video_grid_thw[video_idx][0]
|
|
321
|
+
grid_hs = video_grid_thw[:, 1]
|
|
322
|
+
grid_ws = video_grid_thw[:, 2]
|
|
323
|
+
t_index = (
|
|
324
|
+
torch.arange(grid_t.item())
|
|
325
|
+
* second_per_grids[video_idx].cpu().float()
|
|
326
|
+
* position_id_per_seconds
|
|
327
|
+
).long()
|
|
328
|
+
llm_pos_ids = self.get_llm_pos_ids_for_vision(
|
|
329
|
+
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
|
330
|
+
)
|
|
331
|
+
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
|
|
332
|
+
llm_pos_ids_list.append(llm_pos_ids)
|
|
333
|
+
|
|
334
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
335
|
+
eos_len = 1
|
|
336
|
+
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
|
|
337
|
+
|
|
338
|
+
st += text_len + bos_len + video_len + eos_len
|
|
339
|
+
video_idx += 1
|
|
340
|
+
remain_videos -= 1
|
|
341
|
+
|
|
342
|
+
elif (
|
|
343
|
+
min_ed == ed_video
|
|
344
|
+
and use_audio_in_video
|
|
345
|
+
and audio_seqlens is not None
|
|
346
|
+
and video_grid_thw is not None
|
|
347
|
+
and second_per_grids is not None
|
|
348
|
+
):
|
|
349
|
+
text_len = min_ed - st - 2
|
|
350
|
+
if text_len != 0:
|
|
351
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
352
|
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
|
353
|
+
|
|
354
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
355
|
+
bos_len = 1
|
|
356
|
+
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
|
|
357
|
+
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
|
|
358
|
+
|
|
359
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
360
|
+
audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1
|
|
361
|
+
audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
|
|
362
|
+
grid_t = video_grid_thw[video_idx][0]
|
|
363
|
+
grid_hs = video_grid_thw[:, 1]
|
|
364
|
+
grid_ws = video_grid_thw[:, 2]
|
|
365
|
+
|
|
366
|
+
t_index = (
|
|
367
|
+
torch.arange(grid_t.item())
|
|
368
|
+
* second_per_grids[video_idx].cpu().float()
|
|
369
|
+
* position_id_per_seconds
|
|
370
|
+
).long()
|
|
371
|
+
video_llm_pos_ids = self.get_llm_pos_ids_for_vision(
|
|
372
|
+
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
|
373
|
+
)
|
|
374
|
+
|
|
375
|
+
t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
|
|
376
|
+
video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids, t_ntoken_per_chunk, st_idx)
|
|
377
|
+
audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids, t_ntoken_per_chunk, st_idx)
|
|
378
|
+
sub_len = 0
|
|
379
|
+
for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
|
|
380
|
+
video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None
|
|
381
|
+
audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None
|
|
382
|
+
if video_chunk_index is not None:
|
|
383
|
+
sub_len += video_chunk_index[1] - video_chunk_index[0]
|
|
384
|
+
|
|
385
|
+
llm_pos_ids_list.append(
|
|
386
|
+
video_llm_pos_ids[:, video_chunk_index[0] : video_chunk_index[1]]
|
|
387
|
+
)
|
|
388
|
+
if audio_chunk_index is not None:
|
|
389
|
+
sub_len += audio_chunk_index[1] - audio_chunk_index[0]
|
|
390
|
+
|
|
391
|
+
llm_pos_ids_list.append(
|
|
392
|
+
audio_llm_pos_ids[:, audio_chunk_index[0] : audio_chunk_index[1]]
|
|
393
|
+
)
|
|
394
|
+
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
|
|
395
|
+
|
|
396
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
397
|
+
eos_len = 1
|
|
398
|
+
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
|
|
399
|
+
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
|
|
400
|
+
|
|
401
|
+
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2
|
|
402
|
+
|
|
403
|
+
audio_idx += 1
|
|
404
|
+
video_idx += 1
|
|
405
|
+
remain_videos -= 1
|
|
406
|
+
remain_audios -= 1
|
|
407
|
+
|
|
408
|
+
if st < len(input_tokens):
|
|
409
|
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
|
410
|
+
text_len = len(input_tokens) - st
|
|
411
|
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
|
412
|
+
|
|
413
|
+
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
|
414
|
+
|
|
415
|
+
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
|
|
416
|
+
mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids_p))
|
|
417
|
+
mrope_position_deltas_p = torch.tensor(mrope_position_deltas, device=input_ids_p.device).unsqueeze(1)
|
|
418
|
+
|
|
419
|
+
return position_ids, mrope_position_deltas_p
|
|
420
|
+
else:
|
|
421
|
+
if attention_mask is not None:
|
|
422
|
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
|
423
|
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
|
424
|
+
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
|
|
425
|
+
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
|
|
426
|
+
mrope_position_deltas_p = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)
|
|
427
|
+
|
|
428
|
+
return position_ids, mrope_position_deltas_p
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
@dataclass
|
|
432
|
+
class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput):
|
|
433
|
+
|
|
434
|
+
loss: Optional[torch.FloatTensor] = None
|
|
435
|
+
logits: Optional[torch.FloatTensor] = None
|
|
436
|
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
|
437
|
+
hidden_states: Optional[Any] = None
|
|
438
|
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
439
|
+
attention_mask: Optional[torch.Tensor] = None
|
|
440
|
+
rope_deltas: Optional[torch.Tensor] = None
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class Qwen2_5OmniAudioAttention(nn.Module):
|
|
444
|
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
|
445
|
+
|
|
446
|
+
def __init__(
|
|
447
|
+
self,
|
|
448
|
+
embed_dim: int,
|
|
449
|
+
num_heads: int,
|
|
450
|
+
dropout: float = 0.0,
|
|
451
|
+
is_decoder: bool = False,
|
|
452
|
+
bias: bool = True,
|
|
453
|
+
is_causal: bool = False,
|
|
454
|
+
layer_idx: Optional[int] = None,
|
|
455
|
+
config: Optional[Qwen2_5OmniThinkerConfig] = None,
|
|
456
|
+
):
|
|
457
|
+
super().__init__()
|
|
458
|
+
self.embed_dim = embed_dim
|
|
459
|
+
self.num_heads = num_heads
|
|
460
|
+
self.dropout = dropout
|
|
461
|
+
self.head_dim = embed_dim // num_heads
|
|
462
|
+
self.config = config
|
|
463
|
+
|
|
464
|
+
if (self.head_dim * num_heads) != self.embed_dim:
|
|
465
|
+
raise ValueError(
|
|
466
|
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
|
467
|
+
f" and `num_heads`: {num_heads})."
|
|
468
|
+
)
|
|
469
|
+
self.scaling = self.head_dim**-0.5
|
|
470
|
+
self.is_decoder = is_decoder
|
|
471
|
+
self.is_causal = is_causal
|
|
472
|
+
|
|
473
|
+
if layer_idx is None and is_decoder:
|
|
474
|
+
logger.warning_once(
|
|
475
|
+
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
|
|
476
|
+
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
|
477
|
+
"when creating this class."
|
|
478
|
+
)
|
|
479
|
+
self.layer_idx = layer_idx
|
|
480
|
+
|
|
481
|
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
|
482
|
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
|
483
|
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
|
484
|
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
|
485
|
+
|
|
486
|
+
def forward(
|
|
487
|
+
self,
|
|
488
|
+
hidden_states: torch.Tensor,
|
|
489
|
+
key_value_states: Optional[torch.Tensor] = None,
|
|
490
|
+
past_key_value: Optional[EncoderDecoderCache] = None,
|
|
491
|
+
cu_seqlens: Optional[torch.Tensor] = None,
|
|
492
|
+
layer_head_mask: Optional[torch.Tensor] = None,
|
|
493
|
+
output_attentions: bool = False,
|
|
494
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
495
|
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
496
|
+
"""Input shape: Batch x Time x Channel"""
|
|
497
|
+
|
|
498
|
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
|
499
|
+
# for the decoder
|
|
500
|
+
is_cross_attention = key_value_states is not None
|
|
501
|
+
seq_length, _ = hidden_states.size()
|
|
502
|
+
|
|
503
|
+
# get query proj
|
|
504
|
+
# query_states = self.q_proj(hidden_states)
|
|
505
|
+
query_states = (hidden_states @ self.q_proj.weight.t()) + self.q_proj.bias
|
|
506
|
+
|
|
507
|
+
query_states = query_states.reshape(seq_length, self.num_heads, -1)
|
|
508
|
+
|
|
509
|
+
if past_key_value is not None:
|
|
510
|
+
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
|
511
|
+
if is_cross_attention:
|
|
512
|
+
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
|
513
|
+
past_key_value.is_updated[self.layer_idx] = True
|
|
514
|
+
past_key_value = past_key_value.cross_attention_cache
|
|
515
|
+
else:
|
|
516
|
+
past_key_value = past_key_value.self_attention_cache
|
|
517
|
+
|
|
518
|
+
# use key_value_states if cross attention
|
|
519
|
+
current_states = key_value_states if key_value_states is not None else hidden_states
|
|
520
|
+
if is_cross_attention and past_key_value and is_updated:
|
|
521
|
+
# reuse k,v, cross_attentions
|
|
522
|
+
key_states = past_key_value.key_cache[self.layer_idx]
|
|
523
|
+
value_states = past_key_value.value_cache[self.layer_idx]
|
|
524
|
+
else:
|
|
525
|
+
key_states = self.k_proj(current_states).reshape(seq_length, self.num_heads, -1)
|
|
526
|
+
value_states = self.v_proj(current_states).reshape(seq_length, self.num_heads, -1)
|
|
527
|
+
if past_key_value is not None:
|
|
528
|
+
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
|
529
|
+
cache_position = cache_position if not is_cross_attention else None
|
|
530
|
+
key_states, value_states = past_key_value.update(
|
|
531
|
+
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
|
532
|
+
)
|
|
533
|
+
|
|
534
|
+
query_states = query_states.transpose(0, 1)
|
|
535
|
+
key_states = key_states.transpose(0, 1)
|
|
536
|
+
value_states = value_states.transpose(0, 1)
|
|
537
|
+
attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
|
|
538
|
+
|
|
539
|
+
attention_mask = torch.full(
|
|
540
|
+
[1, seq_length, key_states.shape[1]],
|
|
541
|
+
torch.finfo(query_states.dtype).min,
|
|
542
|
+
device=query_states.device,
|
|
543
|
+
dtype=query_states.dtype,
|
|
544
|
+
)
|
|
545
|
+
assert cu_seqlens is not None
|
|
546
|
+
for i in range(1, cu_seqlens.size(0)):
|
|
547
|
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
|
548
|
+
|
|
549
|
+
attn_weights = attn_weights + attention_mask
|
|
550
|
+
|
|
551
|
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
|
|
552
|
+
|
|
553
|
+
if layer_head_mask is not None:
|
|
554
|
+
if layer_head_mask.size() != (self.num_heads,):
|
|
555
|
+
raise ValueError(
|
|
556
|
+
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
|
557
|
+
f" {layer_head_mask.size()}"
|
|
558
|
+
)
|
|
559
|
+
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
|
|
560
|
+
|
|
561
|
+
attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1).reshape(seq_length, self.embed_dim)
|
|
562
|
+
|
|
563
|
+
attn_output = self.out_proj(attn_output)
|
|
564
|
+
|
|
565
|
+
return attn_output, attn_weights, past_key_value
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention):
|
|
569
|
+
|
|
570
|
+
def __init__(self, *args, **kwargs):
|
|
571
|
+
super().__init__(*args, **kwargs)
|
|
572
|
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
|
573
|
+
|
|
574
|
+
def forward(
|
|
575
|
+
self,
|
|
576
|
+
hidden_states: torch.Tensor,
|
|
577
|
+
key_value_states: Optional[torch.Tensor] = None,
|
|
578
|
+
past_key_value: Optional[EncoderDecoderCache] = None,
|
|
579
|
+
cu_seqlens: Optional[torch.Tensor] = None,
|
|
580
|
+
layer_head_mask: Optional[torch.Tensor] = None,
|
|
581
|
+
output_attentions: bool = False,
|
|
582
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
583
|
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
584
|
+
if isinstance(past_key_value, StaticCache):
|
|
585
|
+
raise ValueError(
|
|
586
|
+
"The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
|
|
587
|
+
"Use `attn_implementation='sdpa'` in the meantime, and open an issue "
|
|
588
|
+
"at https://github.com/huggingface/transformers"
|
|
589
|
+
)
|
|
590
|
+
# Qwen2.5OmniThinkerFlashAttention2 attention does not support output_attentions
|
|
591
|
+
if output_attentions:
|
|
592
|
+
raise ValueError("Qwen2.5OmniThinkerFlashAttention2 attention does not support output_attentions")
|
|
593
|
+
|
|
594
|
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
|
595
|
+
# for the decoder
|
|
596
|
+
is_cross_attention = key_value_states is not None
|
|
597
|
+
seq_length, all_dim = hidden_states.size()
|
|
598
|
+
query_states = (hidden_states @ self.q_proj.weight.t()) + (
|
|
599
|
+
self.q_proj.bias if self.q_proj.bias is not None else 0
|
|
600
|
+
)
|
|
601
|
+
query_states = query_states.reshape(seq_length, self.num_heads, -1)
|
|
602
|
+
|
|
603
|
+
if past_key_value is not None:
|
|
604
|
+
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
|
605
|
+
if is_cross_attention:
|
|
606
|
+
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
|
607
|
+
past_key_value.is_updated[self.layer_idx] = True
|
|
608
|
+
past_key_value = past_key_value.cross_attention_cache
|
|
609
|
+
else:
|
|
610
|
+
past_key_value = past_key_value.self_attention_cache
|
|
611
|
+
|
|
612
|
+
# use key_value_states if cross attention
|
|
613
|
+
current_states = key_value_states if key_value_states is not None else hidden_states
|
|
614
|
+
if is_cross_attention and past_key_value and is_updated:
|
|
615
|
+
# reuse k,v, cross_attentions
|
|
616
|
+
key_states = past_key_value.key_cache[self.layer_idx]
|
|
617
|
+
value_states = past_key_value.value_cache[self.layer_idx]
|
|
618
|
+
else:
|
|
619
|
+
key_states = (current_states @ self.k_proj.weight.t()) + (
|
|
620
|
+
self.k_proj.bias if self.k_proj.bias is not None else 0
|
|
621
|
+
)
|
|
622
|
+
key_states = key_states.reshape(seq_length, self.num_heads, -1)
|
|
623
|
+
value_states = (current_states @ self.v_proj.weight.t()) + (
|
|
624
|
+
self.v_proj.bias if self.v_proj.bias is not None else 0
|
|
625
|
+
)
|
|
626
|
+
value_states = value_states.reshape(seq_length, self.num_heads, -1)
|
|
627
|
+
|
|
628
|
+
if past_key_value is not None:
|
|
629
|
+
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
|
630
|
+
cache_position = cache_position if not is_cross_attention else None
|
|
631
|
+
key_states, value_states = past_key_value.update(
|
|
632
|
+
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
|
633
|
+
)
|
|
634
|
+
assert cu_seqlens is not None
|
|
635
|
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
|
636
|
+
attn_output = flash_attn_varlen_func(
|
|
637
|
+
query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
|
|
638
|
+
)
|
|
639
|
+
attn_output = attn_output.reshape(seq_length, all_dim)
|
|
640
|
+
attn_output = (attn_output @ self.out_proj.weight.t()) + (
|
|
641
|
+
self.out_proj.bias if self.out_proj.bias is not None else 0
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
if not output_attentions:
|
|
645
|
+
attn_weights = None
|
|
646
|
+
|
|
647
|
+
return attn_output, attn_weights, past_key_value
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
class Qwen2_5OmniAudioSdpaAttention(Qwen2_5OmniAudioAttention):
|
|
651
|
+
def forward(
|
|
652
|
+
self,
|
|
653
|
+
hidden_states: torch.Tensor,
|
|
654
|
+
key_value_states: Optional[torch.Tensor] = None,
|
|
655
|
+
past_key_value: Optional[EncoderDecoderCache] = None,
|
|
656
|
+
cu_seqlens: Optional[torch.Tensor] = None,
|
|
657
|
+
layer_head_mask: Optional[torch.Tensor] = None,
|
|
658
|
+
output_attentions: bool = False,
|
|
659
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
660
|
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
661
|
+
"""Input shape: Batch x Time x Channel"""
|
|
662
|
+
if output_attentions or layer_head_mask is not None:
|
|
663
|
+
logger.warning_once(
|
|
664
|
+
"Qwen2_5OmniThinkerModel is using Qwen2_5OmniThinkerSdpaAttention, but "
|
|
665
|
+
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` "
|
|
666
|
+
"or `layer_head_mask` not None. Falling back to the manual attention"
|
|
667
|
+
' implementation, but specifying the manual implementation will be required "'
|
|
668
|
+
'"from Transformers version v5.0.0 onwards. This warning can be removed using the argument"'
|
|
669
|
+
'" `attn_implementation="eager"` when loading the model.'
|
|
670
|
+
)
|
|
671
|
+
return super().forward(
|
|
672
|
+
hidden_states,
|
|
673
|
+
key_value_states=key_value_states,
|
|
674
|
+
past_key_value=past_key_value,
|
|
675
|
+
cu_seqlens=cu_seqlens,
|
|
676
|
+
layer_head_mask=layer_head_mask,
|
|
677
|
+
output_attentions=output_attentions,
|
|
678
|
+
cache_position=cache_position,
|
|
679
|
+
)
|
|
680
|
+
|
|
681
|
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
|
682
|
+
# for the decoder
|
|
683
|
+
is_cross_attention = key_value_states is not None
|
|
684
|
+
seq_length, _ = hidden_states.size()
|
|
685
|
+
|
|
686
|
+
# get query proj
|
|
687
|
+
query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
|
|
688
|
+
|
|
689
|
+
if past_key_value is not None:
|
|
690
|
+
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
|
691
|
+
if is_cross_attention:
|
|
692
|
+
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
|
693
|
+
past_key_value.is_updated[self.layer_idx] = True
|
|
694
|
+
past_key_value = past_key_value.cross_attention_cache
|
|
695
|
+
else:
|
|
696
|
+
past_key_value = past_key_value.self_attention_cache
|
|
697
|
+
|
|
698
|
+
# use key_value_states if cross attention
|
|
699
|
+
current_states = key_value_states if key_value_states is not None else hidden_states
|
|
700
|
+
if is_cross_attention and past_key_value and is_updated:
|
|
701
|
+
# reuse k,v, cross_attentions
|
|
702
|
+
key_states = past_key_value.key_cache[self.layer_idx]
|
|
703
|
+
value_states = past_key_value.value_cache[self.layer_idx]
|
|
704
|
+
else:
|
|
705
|
+
key_states = self.k_proj(current_states).reshape(seq_length, self.num_heads, -1)
|
|
706
|
+
value_states = self.v_proj(current_states).reshape(seq_length, self.num_heads, -1)
|
|
707
|
+
if past_key_value is not None:
|
|
708
|
+
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
|
709
|
+
cache_position = cache_position if not is_cross_attention else None
|
|
710
|
+
key_states, value_states = past_key_value.update(
|
|
711
|
+
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
|
712
|
+
)
|
|
713
|
+
|
|
714
|
+
attention_mask = torch.zeros([1, seq_length, key_states.shape[0]], device=query_states.device, dtype=torch.bool)
|
|
715
|
+
assert cu_seqlens is not None
|
|
716
|
+
for i in range(1, cu_seqlens.size(0)):
|
|
717
|
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
|
718
|
+
|
|
719
|
+
query_states = query_states.transpose(0, 1)
|
|
720
|
+
key_states = key_states.transpose(0, 1)
|
|
721
|
+
value_states = value_states.transpose(0, 1)
|
|
722
|
+
|
|
723
|
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
|
724
|
+
query_states,
|
|
725
|
+
key_states,
|
|
726
|
+
value_states,
|
|
727
|
+
attn_mask=attention_mask,
|
|
728
|
+
dropout_p=self.dropout if self.training else 0.0,
|
|
729
|
+
)
|
|
730
|
+
attn_output = attn_output.transpose(0, 1)
|
|
731
|
+
|
|
732
|
+
attn_output = attn_output.reshape(seq_length, self.embed_dim)
|
|
733
|
+
attn_output = self.out_proj(attn_output)
|
|
734
|
+
return attn_output, None, past_key_value
|
|
735
|
+
|
|
736
|
+
|
|
737
|
+
QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES = {
|
|
738
|
+
"eager": Qwen2_5OmniAudioAttention,
|
|
739
|
+
"flash_attention_2": Qwen2_5OmniAudioFlashAttention2,
|
|
740
|
+
"sdpa": Qwen2_5OmniAudioSdpaAttention,
|
|
741
|
+
}
|
|
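For orientation, the dictionary above is a plain string-keyed registry: the encoder layer looks up the attention class by the config's `_attn_implementation` value and instantiates it. A minimal standalone sketch of that dispatch pattern follows; the `DummyConfig`, `EagerAttention`, and `SdpaAttention` names are illustrative stand-ins, not part of the diff.

# Minimal, self-contained sketch of string-keyed attention dispatch (illustrative names only).
from dataclasses import dataclass


class EagerAttention:
    def __init__(self, embed_dim: int):
        self.embed_dim = embed_dim


class SdpaAttention(EagerAttention):
    pass


ATTENTION_CLASSES = {"eager": EagerAttention, "sdpa": SdpaAttention}


@dataclass
class DummyConfig:
    _attn_implementation: str = "sdpa"  # hypothetical stand-in for config._attn_implementation
    d_model: int = 64


config = DummyConfig()
attn = ATTENTION_CLASSES[config._attn_implementation](embed_dim=config.d_model)
assert isinstance(attn, SdpaAttention)  # class chosen purely by the config string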
742
|
+
|
|
743
|
+
|
|
744
|
+
class Qwen2_5OmniAudioEncoderLayer(nn.Module):
|
|
745
|
+
def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
|
|
746
|
+
super().__init__()
|
|
747
|
+
self.embed_dim = config.d_model
|
|
748
|
+
self.self_attn = QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES[config._attn_implementation](
|
|
749
|
+
embed_dim=self.embed_dim,
|
|
750
|
+
num_heads=config.encoder_attention_heads,
|
|
751
|
+
dropout=config.attention_dropout,
|
|
752
|
+
config=config,
|
|
753
|
+
)
|
|
754
|
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
|
755
|
+
self.dropout = config.dropout
|
|
756
|
+
self.activation_fn = ACT2FN[config.activation_function]
|
|
757
|
+
self.activation_dropout = config.activation_dropout
|
|
758
|
+
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
|
759
|
+
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
|
760
|
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
|
761
|
+
|
|
762
|
+
def forward(
|
|
763
|
+
self,
|
|
764
|
+
hidden_states: torch.Tensor,
|
|
765
|
+
cu_seqlens: torch.Tensor,
|
|
766
|
+
layer_head_mask: torch.Tensor,
|
|
767
|
+
output_attentions: bool = False,
|
|
768
|
+
):
|
|
769
|
+
residual = hidden_states
|
|
770
|
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
|
771
|
+
hidden_states, attn_weights, _ = self.self_attn(
|
|
772
|
+
hidden_states=hidden_states,
|
|
773
|
+
cu_seqlens=cu_seqlens,
|
|
774
|
+
layer_head_mask=layer_head_mask,
|
|
775
|
+
output_attentions=output_attentions,
|
|
776
|
+
)
|
|
777
|
+
hidden_states = residual + hidden_states
|
|
778
|
+
residual = hidden_states
|
|
779
|
+
hidden_states = self.final_layer_norm(hidden_states)
|
|
780
|
+
hidden_states = (hidden_states @ self.fc1.weight.t()) + (self.fc1.bias if self.fc1.bias is not None else 0)
|
|
781
|
+
hidden_states = self.activation_fn(hidden_states)
|
|
782
|
+
hidden_states = (hidden_states @ self.fc2.weight.t()) + (self.fc2.bias if self.fc2.bias is not None else 0)
|
|
783
|
+
hidden_states = residual + hidden_states
|
|
784
|
+
|
|
785
|
+
if hidden_states.dtype == torch.float16 and (
|
|
786
|
+
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
|
|
787
|
+
):
|
|
788
|
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
|
789
|
+
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
|
790
|
+
|
|
791
|
+
outputs: Tuple[Any, ...]
|
|
792
|
+
outputs = (hidden_states,)
|
|
793
|
+
|
|
794
|
+
if output_attentions and attn_weights is not None:
|
|
795
|
+
outputs += (attn_weights,)
|
|
796
|
+
|
|
797
|
+
return outputs
|
|
798
|
+
|
|
799
|
+
|
|
800
|
+
class SinusoidsPositionEmbedding(nn.Module):
|
|
801
|
+
def __init__(self, length, channels, max_timescale=10000):
|
|
802
|
+
super().__init__()
|
|
803
|
+
assert channels % 2 == 0
|
|
804
|
+
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
|
805
|
+
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
|
806
|
+
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
|
807
|
+
self.register_buffer(
|
|
808
|
+
"positional_embedding",
|
|
809
|
+
torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
|
|
810
|
+
persistent=False,
|
|
811
|
+
)
|
|
812
|
+
|
|
813
|
+
def forward(self, seqlen: int):
|
|
814
|
+
return self.positional_embedding[:seqlen, :]
|
|
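To make the buffer built above concrete, here is a small standalone sketch (with arbitrary illustrative sizes) of the same construction: each position's row is the sine half followed by the cosine half of the scaled timescales, so the table has shape (length, channels).

# Standalone sketch of the sinusoidal table construction used by SinusoidsPositionEmbedding.
import numpy as np
import torch

length, channels, max_timescale = 8, 16, 10000
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
table = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
assert table.shape == (length, channels)  # one row per position: sin half, then cos half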
815
|
+
|
|
816
|
+
|
|
817
|
+
class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
|
|
818
|
+
"""
|
|
819
|
+
Transformer encoder consisting of *config.encoder_layers* self
|
|
820
|
+
attention layers. Each layer is a [`Qwen2_5OmniAudioEncoderLayer`].
|
|
821
|
+
|
|
822
|
+
Args:
|
|
823
|
+
config: Qwen2_5OmniAudioEncoderConfig
|
|
824
|
+
"""
|
|
825
|
+
|
|
826
|
+
config_class = Qwen2_5OmniAudioEncoderConfig
|
|
827
|
+
main_input_name = "input_features"
|
|
828
|
+
_no_split_modules = ["Qwen2_5OmniAudioEncoderLayer"]
|
|
829
|
+
_supports_sdpa = True
|
|
830
|
+
|
|
831
|
+
def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
|
|
832
|
+
super().__init__(config)
|
|
833
|
+
self.dropout = config.dropout
|
|
834
|
+
self.layerdrop = config.encoder_layerdrop
|
|
835
|
+
|
|
836
|
+
embed_dim = config.d_model
|
|
837
|
+
self.num_mel_bins = config.num_mel_bins
|
|
838
|
+
self.max_source_positions = config.max_source_positions
|
|
839
|
+
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
|
840
|
+
self.n_window = config.n_window
|
|
841
|
+
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
|
|
842
|
+
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
|
|
843
|
+
self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim)
|
|
844
|
+
self.audio_bos_eos_token = nn.Embedding(2, config.output_dim)
|
|
845
|
+
self.layers = nn.ModuleList([Qwen2_5OmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)])
|
|
846
|
+
self.ln_post = nn.LayerNorm(config.d_model)
|
|
847
|
+
self.avg_pooler = nn.AvgPool1d(2, stride=2)
|
|
848
|
+
self.proj = nn.Linear(config.d_model, config.output_dim)
|
|
849
|
+
self.gradient_checkpointing = False
|
|
850
|
+
# Initialize weights and apply final processing
|
|
851
|
+
self.post_init()
|
|
852
|
+
|
|
853
|
+
def _freeze_parameters(self):
|
|
854
|
+
for param in self.parameters():
|
|
855
|
+
param.requires_grad = False
|
|
856
|
+
self._requires_grad = False
|
|
857
|
+
|
|
858
|
+
def get_input_embeddings(self) -> nn.Module:
|
|
859
|
+
return self.conv1
|
|
860
|
+
|
|
861
|
+
def set_input_embeddings(self, value):
|
|
862
|
+
self.conv1 = value
|
|
863
|
+
|
|
864
|
+
def forward(
|
|
865
|
+
self,
|
|
866
|
+
input_features,
|
|
867
|
+
feature_lens=None,
|
|
868
|
+
aftercnn_lens=None,
|
|
869
|
+
head_mask=None,
|
|
870
|
+
output_attentions=None,
|
|
871
|
+
output_hidden_states=None,
|
|
872
|
+
return_dict=None,
|
|
873
|
+
):
|
|
874
|
+
|
|
875
|
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
876
|
+
output_hidden_states = (
|
|
877
|
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
878
|
+
)
|
|
879
|
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
880
|
+
|
|
881
|
+
chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
|
|
882
|
+
|
|
883
|
+
chunk_lengths = torch.tensor(
|
|
884
|
+
[self.n_window * 2] * chunk_num.sum(),
|
|
885
|
+
dtype=torch.long,
|
|
886
|
+
device=feature_lens.device,
|
|
887
|
+
)
|
|
888
|
+
tail_chunk_index = list(accumulate(chunk_num.tolist(), func=operator.add, initial=-1))[1:]
|
|
889
|
+
chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
|
|
890
|
+
chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths)
|
|
891
|
+
|
|
892
|
+
chunk_list = input_features.split(chunk_lengths.tolist(), dim=1)
|
|
893
|
+
padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function(
|
|
894
|
+
chunk_list, chunk_lengths, padding_value=0, padding_side="right"
|
|
895
|
+
)
|
|
896
|
+
padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask
|
|
897
|
+
padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2)
|
|
898
|
+
|
|
899
|
+
padded_embed = padded_embed + self.positional_embedding.positional_embedding[
|
|
900
|
+
: padded_embed.shape[1], :
|
|
901
|
+
].unsqueeze(0).to(padded_embed.dtype)
|
|
902
|
+
hidden_states = padded_embed[padded_mask_after_cnn]
|
|
903
|
+
cu_seqlens = torch.cat(
|
|
904
|
+
(
|
|
905
|
+
torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32),
|
|
906
|
+
padded_mask_after_cnn.sum(1).cumsum(0),
|
|
907
|
+
)
|
|
908
|
+
).to(torch.int32)
|
|
909
|
+
encoder_states: Optional[Tuple[Any, ...]] = () if output_hidden_states else None
|
|
910
|
+
all_attentions: Optional[Tuple[Any, ...]] = () if output_attentions else None
|
|
911
|
+
|
|
912
|
+
tmp_hidden_states = []
|
|
913
|
+
# check if head_mask has a correct number of layers specified if desired
|
|
914
|
+
if head_mask is not None and head_mask.size()[0] != (len(self.layers)):
|
|
915
|
+
raise ValueError(
|
|
916
|
+
f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
|
917
|
+
)
|
|
918
|
+
|
|
919
|
+
for idx, encoder_layer in enumerate(self.layers):
|
|
920
|
+
if output_hidden_states and encoder_states is not None and hidden_states is not None:
|
|
921
|
+
encoder_states = encoder_states + (hidden_states,)
|
|
922
|
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
|
923
|
+
to_drop = False
|
|
924
|
+
if self.training:
|
|
925
|
+
dropout_probability = torch.rand([])
|
|
926
|
+
if dropout_probability < self.layerdrop: # skip the layer
|
|
927
|
+
to_drop = True
|
|
928
|
+
|
|
929
|
+
# Ignore copy
|
|
930
|
+
if to_drop:
|
|
931
|
+
layer_outputs = (None, None)
|
|
932
|
+
else:
|
|
933
|
+
if self.gradient_checkpointing and self.training:
|
|
934
|
+
layer_outputs = self._gradient_checkpointing_func(
|
|
935
|
+
encoder_layer.__call__,
|
|
936
|
+
hidden_states,
|
|
937
|
+
cu_seqlens,
|
|
938
|
+
(head_mask[idx] if head_mask is not None else None),
|
|
939
|
+
output_attentions,
|
|
940
|
+
)
|
|
941
|
+
else:
|
|
942
|
+
layer_outputs = encoder_layer(
|
|
943
|
+
hidden_states,
|
|
944
|
+
cu_seqlens,
|
|
945
|
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
|
946
|
+
output_attentions=output_attentions,
|
|
947
|
+
)
|
|
948
|
+
|
|
949
|
+
hidden_states = layer_outputs[0]
|
|
950
|
+
tmp_hidden_states.append(hidden_states)
|
|
951
|
+
|
|
952
|
+
if output_attentions and all_attentions is not None and layer_outputs is not None:
|
|
953
|
+
all_attentions = all_attentions + (layer_outputs[1],)
|
|
954
|
+
|
|
955
|
+
hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
|
|
956
|
+
token_audio_list = []
|
|
957
|
+
for each_audio_states in hidden_states_list:
|
|
958
|
+
each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1)
|
|
959
|
+
each_audio_states = self.ln_post(each_audio_states)
|
|
960
|
+
each_audio_states = self.proj(each_audio_states)
|
|
961
|
+
token_audio_list.append(each_audio_states)
|
|
962
|
+
token_audio = torch.cat(token_audio_list, dim=0)
|
|
963
|
+
if output_hidden_states and encoder_states is not None and token_audio is not None:
|
|
964
|
+
encoder_states = encoder_states + (token_audio,)
|
|
965
|
+
|
|
966
|
+
if not return_dict:
|
|
967
|
+
return tuple(v for v in [token_audio, encoder_states, all_attentions] if v is not None)
|
|
968
|
+
return BaseModelOutput(last_hidden_state=token_audio, hidden_states=encoder_states, attentions=all_attentions)
|
|
969
|
+
|
|
970
|
+
def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"):
|
|
971
|
+
max_len = tensor_len.max()
|
|
972
|
+
dim = tensor_list[0].shape[0]
|
|
973
|
+
padded_tensor = torch.full(
|
|
974
|
+
size=(len(tensor_list), dim, max_len),
|
|
975
|
+
fill_value=padding_value,
|
|
976
|
+
dtype=tensor_list[0].dtype,
|
|
977
|
+
device=tensor_list[0].device,
|
|
978
|
+
)
|
|
979
|
+
|
|
980
|
+
batch_mask = torch.zeros(
|
|
981
|
+
(len(tensor_len), max_len),
|
|
982
|
+
dtype=torch.long,
|
|
983
|
+
device=padded_tensor.device,
|
|
984
|
+
)
|
|
985
|
+
for i, length in enumerate(tensor_len):
|
|
986
|
+
batch_mask[i, :length] = 1
|
|
987
|
+
padded_tensor[i, :, :length] = tensor_list[i]
|
|
988
|
+
|
|
989
|
+
feature_lens_after_cnn = (tensor_len - 1) // 2 + 1
|
|
990
|
+
max_len_after_cnn = feature_lens_after_cnn.max()
|
|
991
|
+
batch_mask_after_cnn = torch.zeros(
|
|
992
|
+
(len(tensor_len), max_len_after_cnn),
|
|
993
|
+
dtype=torch.long,
|
|
994
|
+
device=padded_tensor.device,
|
|
995
|
+
)
|
|
996
|
+
for i, length in enumerate(feature_lens_after_cnn):
|
|
997
|
+
batch_mask_after_cnn[i, :length] = 1
|
|
998
|
+
return (
|
|
999
|
+
padded_tensor,
|
|
1000
|
+
batch_mask.unsqueeze(1),
|
|
1001
|
+
batch_mask_after_cnn.bool(),
|
|
1002
|
+
)
|
|

1003
|
+
|
|
1004
|
+
# Ignore copy
|
|
1005
|
+
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
|
1006
|
+
"""
|
|
1007
|
+
Computes the output length of the convolutional layers and the output length of the audio encoder
|
|
1008
|
+
"""
|
|
1009
|
+
input_lengths = (input_lengths - 1) // 2 + 1
|
|
1010
|
+
output_lengths = (input_lengths - 2) // 2 + 1
|
|
1011
|
+
return input_lengths, output_lengths
|
|
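To make the arithmetic above concrete, here is a pure-Python worked example of the two formulas (values chosen for illustration): the stride-2 convolution maps L to (L - 1) // 2 + 1, and the stride-2 average pool then maps that to (L' - 2) // 2 + 1.

# Worked example of the length arithmetic in _get_feat_extract_output_lengths (pure Python).
def feat_extract_output_lengths(input_length: int) -> tuple:
    after_cnn = (input_length - 1) // 2 + 1   # conv2: kernel 3, stride 2, padding 1
    after_pool = (after_cnn - 2) // 2 + 1     # avg_pooler: kernel 2, stride 2
    return after_cnn, after_pool


assert feat_extract_output_lengths(3000) == (1500, 750)
assert feat_extract_output_lengths(100) == (50, 25)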
1012
|
+
|
|
1013
|
+
|
|
1014
|
+
def rotate_half(x):
|
|
1015
|
+
"""Rotates half the hidden dims of the input."""
|
|
1016
|
+
x1 = x[..., : x.shape[-1] // 2]
|
|
1017
|
+
x2 = x[..., x.shape[-1] // 2 :]
|
|
1018
|
+
return torch.cat((-x2, x1), dim=-1)
|
|
1019
|
+
|
|
1020
|
+
|
|
1021
|
+
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
|
1022
|
+
orig_dtype = tensor.dtype
|
|
1023
|
+
tensor = tensor.float()
|
|
1024
|
+
cos = freqs.cos()
|
|
1025
|
+
sin = freqs.sin()
|
|
1026
|
+
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
|
1027
|
+
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
|
1028
|
+
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
|
1029
|
+
output = output.to(orig_dtype)
|
|
1030
|
+
return output
|
|
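A tiny self-contained check of the rotate_half convention used by the rotary helpers above: for a head dimension of 4, [x1, x2, x3, x4] becomes [-x3, -x4, x1, x2], which is what makes (t * cos + rotate_half(t) * sin) act as a pairwise rotation of the channels.

# Standalone check of the rotate_half convention used by the rotary embedding helpers.
import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
assert torch.equal(rotate_half(x), torch.tensor([-3.0, -4.0, 1.0, 2.0]))

# With angle 0 the rotation is the identity: cos = 1, sin = 0.
cos, sin = torch.ones_like(x), torch.zeros_like(x)
assert torch.equal(x * cos + rotate_half(x) * sin, x)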
1031
|
+
|
|
1032
|
+
|
|
1033
|
+
class Qwen2_5OmniVisionAttention(nn.Module):
|
|
1034
|
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
|
1035
|
+
super().__init__()
|
|
1036
|
+
self.num_heads = num_heads
|
|
1037
|
+
self.head_dim = dim // num_heads
|
|
1038
|
+
self.q = nn.Linear(dim, dim, bias=True)
|
|
1039
|
+
self.k = nn.Linear(dim, dim, bias=True)
|
|
1040
|
+
self.v = nn.Linear(dim, dim, bias=True)
|
|
1041
|
+
self.proj = nn.Linear(dim, dim)
|
|
1042
|
+
|
|
1043
|
+
def forward(
|
|
1044
|
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
|
|
1045
|
+
) -> torch.Tensor:
|
|
1046
|
+
seq_length = hidden_states.shape[0]
|
|
1047
|
+
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
|
|
1048
|
+
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
|
|
1049
|
+
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
|
|
1050
|
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
|
1051
|
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
|
1052
|
+
|
|
1053
|
+
attention_mask = torch.full(
|
|
1054
|
+
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
|
|
1055
|
+
)
|
|
1056
|
+
for i in range(1, len(cu_seqlens)):
|
|
1057
|
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
|
1058
|
+
|
|
1059
|
+
q = q.transpose(0, 1)
|
|
1060
|
+
k = k.transpose(0, 1)
|
|
1061
|
+
v = v.transpose(0, 1)
|
|
1062
|
+
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
|
|
1063
|
+
attn_weights = attn_weights + attention_mask
|
|
1064
|
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
|
1065
|
+
attn_output = torch.matmul(attn_weights, v)
|
|
1066
|
+
attn_output = attn_output.transpose(0, 1)
|
|
1067
|
+
attn_output = attn_output.reshape(seq_length, -1)
|
|
1068
|
+
attn_output = self.proj(attn_output)
|
|
1069
|
+
return attn_output
|
|
1070
|
+
|
|
1071
|
+
|
|
1072
|
+
class Qwen2_5OmniVisionFlashAttention2(nn.Module):
|
|
1073
|
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
|
1074
|
+
super().__init__()
|
|
1075
|
+
self.num_heads = num_heads
|
|
1076
|
+
self.q = nn.Linear(dim, dim, bias=True)
|
|
1077
|
+
self.k = nn.Linear(dim, dim, bias=True)
|
|
1078
|
+
self.v = nn.Linear(dim, dim, bias=True)
|
|
1079
|
+
self.proj = nn.Linear(dim, dim)
|
|
1080
|
+
|
|
1081
|
+
def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
|
1082
|
+
tensor_ = tensor.float()
|
|
1083
|
+
cos = freqs.cos() # .type_as(tensor_)
|
|
1084
|
+
sin = freqs.sin() # .type_as(tensor_)
|
|
1085
|
+
output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
|
|
1086
|
+
return output
|
|
1087
|
+
|
|
1088
|
+
def forward(
|
|
1089
|
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
|
|
1090
|
+
) -> torch.Tensor:
|
|
1091
|
+
seq_length = hidden_states.shape[0]
|
|
1092
|
+
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
|
|
1093
|
+
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
|
|
1094
|
+
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
|
|
1095
|
+
q = self._apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
|
1096
|
+
k = self._apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
|
1097
|
+
|
|
1098
|
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
|
1099
|
+
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
|
1100
|
+
seq_length, -1
|
|
1101
|
+
)
|
|
1102
|
+
attn_output = self.proj(attn_output)
|
|
1103
|
+
return attn_output
|
|
1104
|
+
|
|
1105
|
+
|
|
1106
|
+
class Qwen2_5OmniVisionSdpaAttention(nn.Module):
|
|
1107
|
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
|
1108
|
+
super().__init__()
|
|
1109
|
+
self.num_heads = num_heads
|
|
1110
|
+
self.q = nn.Linear(dim, dim, bias=True)
|
|
1111
|
+
self.k = nn.Linear(dim, dim, bias=True)
|
|
1112
|
+
self.v = nn.Linear(dim, dim, bias=True)
|
|
1113
|
+
self.proj = nn.Linear(dim, dim)
|
|
1114
|
+
|
|
1115
|
+
def forward(
|
|
1116
|
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
|
|
1117
|
+
) -> torch.Tensor:
|
|
1118
|
+
seq_length = hidden_states.shape[0]
|
|
1119
|
+
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
|
|
1120
|
+
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
|
|
1121
|
+
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
|
|
1122
|
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
|
1123
|
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
|
1124
|
+
|
|
1125
|
+
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
|
|
1126
|
+
for i in range(1, len(cu_seqlens)):
|
|
1127
|
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
|
1128
|
+
q = q.transpose(0, 1)
|
|
1129
|
+
k = k.transpose(0, 1)
|
|
1130
|
+
v = v.transpose(0, 1)
|
|
1131
|
+
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
|
1132
|
+
attn_output = attn_output.transpose(0, 1)
|
|
1133
|
+
attn_output = attn_output.reshape(seq_length, -1)
|
|
1134
|
+
attn_output = self.proj(attn_output)
|
|
1135
|
+
return attn_output
|
|
1136
|
+
|
|
1137
|
+
|
|
1138
|
+
class Qwen2_5OmniMLP(nn.Module):
|
|
1139
|
+
def __init__(self, config, bias: bool = False):
|
|
1140
|
+
super().__init__()
|
|
1141
|
+
self.hidden_size = config.hidden_size
|
|
1142
|
+
self.intermediate_size = config.intermediate_size
|
|
1143
|
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
|
|
1144
|
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
|
|
1145
|
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
|
|
1146
|
+
self.act_fn = ACT2FN[config.hidden_act]
|
|
1147
|
+
|
|
1148
|
+
def forward(self, hidden_state):
|
|
1149
|
+
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
|
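The MLP above is the usual gated ("SwiGLU-style") block: down_proj(act(gate_proj(x)) * up_proj(x)). A minimal shape sketch follows, with hypothetical sizes and SiLU assumed as the activation (the real activation comes from config.hidden_act).

# Minimal shape sketch of the gated MLP pattern (hypothetical sizes, SiLU assumed).
import torch
import torch.nn as nn

hidden_size, intermediate_size = 8, 32
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
act_fn = nn.SiLU()

x = torch.randn(2, 5, hidden_size)                # (batch, seq, hidden)
y = down_proj(act_fn(gate_proj(x)) * up_proj(x))  # the gate modulates the up projection
assert y.shape == x.shape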
1150
|
+
|
|
1151
|
+
|
|
1152
|
+
class Qwen2RMSNorm(nn.Module):
|
|
1153
|
+
def __init__(self, hidden_size, eps=1e-6):
|
|
1154
|
+
"""
|
|
1155
|
+
Qwen2RMSNorm is equivalent to T5LayerNorm
|
|
1156
|
+
"""
|
|
1157
|
+
super().__init__()
|
|
1158
|
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
1159
|
+
self.variance_epsilon = eps
|
|
1160
|
+
|
|
1161
|
+
def forward(self, hidden_states):
|
|
1162
|
+
input_dtype = hidden_states.dtype
|
|
1163
|
+
hidden_states = hidden_states.to(torch.float32)
|
|
1164
|
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
1165
|
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
|
1166
|
+
return self.weight * hidden_states.to(input_dtype)
|
|
1167
|
+
|
|
1168
|
+
def extra_repr(self):
|
|
1169
|
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
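Qwen2RMSNorm normalizes by the root-mean-square of the features (no mean subtraction, unlike LayerNorm) and rescales with a learned weight. A quick standalone check that the rsqrt formulation above matches the closed form w * x / sqrt(mean(x^2) + eps):

# Standalone check of the RMSNorm math: weight * x / sqrt(mean(x^2) + eps).
import torch

eps = 1e-6
x = torch.randn(2, 3, 8)
weight = torch.ones(8)

variance = x.pow(2).mean(-1, keepdim=True)
normalized = weight * (x * torch.rsqrt(variance + eps))

expected = weight * x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
assert torch.allclose(normalized, expected)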
1170
|
+
|
|
1171
|
+
|
|
1172
|
+
QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = {
|
|
1173
|
+
"eager": Qwen2_5OmniVisionAttention,
|
|
1174
|
+
"flash_attention_2": Qwen2_5OmniVisionFlashAttention2,
|
|
1175
|
+
"sdpa": Qwen2_5OmniVisionSdpaAttention,
|
|
1176
|
+
}
|
|
1177
|
+
|
|
1178
|
+
|
|
1179
|
+
class Qwen2_5OmniVisionBlock(nn.Module):
|
|
1180
|
+
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
|
|
1181
|
+
super().__init__()
|
|
1182
|
+
self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
|
|
1183
|
+
self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
|
|
1184
|
+
self.attn = QWEN2_5_OMNI_VISION_ATTENTION_CLASSES[attn_implementation](
|
|
1185
|
+
config.hidden_size, num_heads=config.num_heads
|
|
1186
|
+
)
|
|
1187
|
+
self.mlp = Qwen2_5OmniMLP(config, bias=True)
|
|
1188
|
+
|
|
1189
|
+
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
|
|
1190
|
+
hidden_states = hidden_states + self.attn(
|
|
1191
|
+
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
|
1192
|
+
)
|
|
1193
|
+
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
|
1194
|
+
return hidden_states
|
|
1195
|
+
|
|
1196
|
+
|
|
1197
|
+
class Qwen2_5_VisionPatchEmbed(nn.Module):
|
|
1198
|
+
def __init__(
|
|
1199
|
+
self,
|
|
1200
|
+
patch_size: int = 14,
|
|
1201
|
+
temporal_patch_size: int = 2,
|
|
1202
|
+
in_channels: int = 3,
|
|
1203
|
+
embed_dim: int = 1152,
|
|
1204
|
+
) -> None:
|
|
1205
|
+
super().__init__()
|
|
1206
|
+
self.patch_size = patch_size
|
|
1207
|
+
self.temporal_patch_size = temporal_patch_size
|
|
1208
|
+
self.in_channels = in_channels
|
|
1209
|
+
self.embed_dim = embed_dim
|
|
1210
|
+
|
|
1211
|
+
kernel_size = (temporal_patch_size, patch_size, patch_size)
|
|
1212
|
+
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
|
|
1213
|
+
|
|
1214
|
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
1215
|
+
target_dtype = self.proj.weight.dtype
|
|
1216
|
+
hidden_states = hidden_states.view(
|
|
1217
|
+
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
|
|
1218
|
+
)
|
|
1219
|
+
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
|
|
1220
|
+
return hidden_states
|
|
1221
|
+
|
|
1222
|
+
|
|
1223
|
+
class Qwen2_5_VisionRotaryEmbedding(nn.Module):
|
|
1224
|
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
|
1225
|
+
super().__init__()
|
|
1226
|
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
|
1227
|
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
1228
|
+
|
|
1229
|
+
def forward(self, seqlen: int) -> torch.Tensor:
|
|
1230
|
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
|
1231
|
+
freqs = torch.outer(seq, self.inv_freq)
|
|
1232
|
+
return freqs
|
|
1233
|
+
|
|
1234
|
+
|
|
1235
|
+
class Qwen2_5OmniPatchMerger(nn.Module):
|
|
1236
|
+
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
|
|
1237
|
+
super().__init__()
|
|
1238
|
+
self.hidden_size = context_dim * (spatial_merge_size**2)
|
|
1239
|
+
self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
|
|
1240
|
+
self.mlp = nn.Sequential(
|
|
1241
|
+
nn.Linear(self.hidden_size, self.hidden_size),
|
|
1242
|
+
nn.GELU(),
|
|
1243
|
+
nn.Linear(self.hidden_size, dim),
|
|
1244
|
+
)
|
|
1245
|
+
|
|
1246
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
1247
|
+
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
|
|
1248
|
+
return x
|
|
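Qwen2_5OmniPatchMerger folds spatial_merge_size² neighbouring patch features into one token before projecting to the LLM width: the view(-1, context_dim * merge²) step is what shortens the sequence by that factor. A shape-only sketch with hypothetical dimensions (the real values come from the vision config):

# Shape-only sketch of the patch-merging step (hypothetical dims).
import torch

context_dim, out_dim, spatial_merge_size = 16, 32, 2
seq_len = 64                                  # must be divisible by spatial_merge_size ** 2
x = torch.randn(seq_len, context_dim)

merged = x.view(-1, context_dim * spatial_merge_size**2)  # (seq_len / 4, 4 * context_dim)
assert merged.shape == (seq_len // spatial_merge_size**2, context_dim * spatial_merge_size**2)

projected = torch.nn.Linear(merged.shape[-1], out_dim)(merged)
assert projected.shape == (seq_len // spatial_merge_size**2, out_dim)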
1249
|
+
|
|
1250
|
+
|
|
1251
|
+
class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
|
|
1252
|
+
config_class = Qwen2_5OmniVisionEncoderConfig
|
|
1253
|
+
_no_split_modules = ["Qwen2_5OmniVisionBlock"]
|
|
1254
|
+
|
|
1255
|
+
def __init__(self, config, *inputs, **kwargs) -> None:
|
|
1256
|
+
super().__init__(config, *inputs, **kwargs)
|
|
1257
|
+
self.spatial_merge_size = config.spatial_merge_size
|
|
1258
|
+
self.patch_size = config.patch_size
|
|
1259
|
+
self.fullatt_block_indexes = config.fullatt_block_indexes
|
|
1260
|
+
self.window_size = config.window_size
|
|
1261
|
+
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
|
|
1262
|
+
|
|
1263
|
+
self.patch_embed = Qwen2_5_VisionPatchEmbed(
|
|
1264
|
+
patch_size=config.patch_size,
|
|
1265
|
+
temporal_patch_size=config.temporal_patch_size,
|
|
1266
|
+
in_channels=config.in_channels,
|
|
1267
|
+
embed_dim=config.hidden_size,
|
|
1268
|
+
)
|
|
1269
|
+
|
|
1270
|
+
head_dim = config.hidden_size // config.num_heads
|
|
1271
|
+
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
|
|
1272
|
+
self.blocks = nn.ModuleList(
|
|
1273
|
+
[Qwen2_5OmniVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
|
|
1274
|
+
)
|
|
1275
|
+
self.merger = Qwen2_5OmniPatchMerger(
|
|
1276
|
+
dim=config.out_hidden_size,
|
|
1277
|
+
context_dim=config.hidden_size,
|
|
1278
|
+
spatial_merge_size=config.spatial_merge_size,
|
|
1279
|
+
)
|
|
1280
|
+
self.gradient_checkpointing = False
|
|
1281
|
+
|
|
1282
|
+
def rot_pos_emb(self, grid_thw):
|
|
1283
|
+
pos_ids = []
|
|
1284
|
+
for t, h, w in grid_thw:
|
|
1285
|
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
|
1286
|
+
hpos_ids = hpos_ids.reshape(
|
|
1287
|
+
h // self.spatial_merge_size,
|
|
1288
|
+
self.spatial_merge_size,
|
|
1289
|
+
w // self.spatial_merge_size,
|
|
1290
|
+
self.spatial_merge_size,
|
|
1291
|
+
)
|
|
1292
|
+
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
|
1293
|
+
hpos_ids = hpos_ids.flatten()
|
|
1294
|
+
|
|
1295
|
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
|
1296
|
+
wpos_ids = wpos_ids.reshape(
|
|
1297
|
+
h // self.spatial_merge_size,
|
|
1298
|
+
self.spatial_merge_size,
|
|
1299
|
+
w // self.spatial_merge_size,
|
|
1300
|
+
self.spatial_merge_size,
|
|
1301
|
+
)
|
|
1302
|
+
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
|
1303
|
+
wpos_ids = wpos_ids.flatten()
|
|
1304
|
+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
|
1305
|
+
pos_ids_p = torch.cat(pos_ids, dim=0)
|
|
1306
|
+
max_grid_size = grid_thw[:, 1:].max()
|
|
1307
|
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
|
1308
|
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids_p].flatten(1)
|
|
1309
|
+
return rotary_pos_emb
|
|
1310
|
+
|
|
1311
|
+
def get_window_index(self, grid_thw):
|
|
1312
|
+
window_index: list = []
|
|
1313
|
+
cu_window_seqlens: list = [0]
|
|
1314
|
+
window_index_id = 0
|
|
1315
|
+
vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
|
|
1316
|
+
|
|
1317
|
+
for grid_t, grid_h, grid_w in grid_thw:
|
|
1318
|
+
llm_grid_h, llm_grid_w = (
|
|
1319
|
+
grid_h // self.spatial_merge_size,
|
|
1320
|
+
grid_w // self.spatial_merge_size,
|
|
1321
|
+
)
|
|
1322
|
+
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
|
|
1323
|
+
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
|
1324
|
+
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
|
1325
|
+
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
|
1326
|
+
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
|
1327
|
+
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
|
|
1328
|
+
index_padded = index_padded.reshape(
|
|
1329
|
+
grid_t,
|
|
1330
|
+
num_windows_h,
|
|
1331
|
+
vit_merger_window_size,
|
|
1332
|
+
num_windows_w,
|
|
1333
|
+
vit_merger_window_size,
|
|
1334
|
+
)
|
|
1335
|
+
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
|
1336
|
+
grid_t,
|
|
1337
|
+
num_windows_h * num_windows_w,
|
|
1338
|
+
vit_merger_window_size,
|
|
1339
|
+
vit_merger_window_size,
|
|
1340
|
+
)
|
|
1341
|
+
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
|
1342
|
+
index_padded = index_padded.reshape(-1)
|
|
1343
|
+
index_new = index_padded[index_padded != -100]
|
|
1344
|
+
window_index.append(index_new + window_index_id)
|
|
1345
|
+
cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
|
|
1346
|
+
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
|
1347
|
+
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
|
1348
|
+
window_index_p = torch.cat(window_index, dim=0)
|
|
1349
|
+
|
|
1350
|
+
return window_index_p, cu_window_seqlens
|
|
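The windowed-attention bookkeeping above reduces to a little arithmetic: merged tokens per side = grid // spatial_merge_size, window side in merged tokens = window_size // spatial_merge_size // patch_size, and the merged grid is padded on the bottom/right so it can be cut into full windows. A small standalone sketch with plausible, assumed values:

# Arithmetic sketch of the window partitioning sizes (values are illustrative assumptions).
window_size, spatial_merge_size, patch_size = 112, 2, 14
grid_h, grid_w = 18, 26                       # patch grid of one image

vit_merger_window_size = window_size // spatial_merge_size // patch_size   # 4 merged tokens per window side
llm_grid_h, llm_grid_w = grid_h // spatial_merge_size, grid_w // spatial_merge_size  # 9 x 13 merged tokens

pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size       # 3
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size       # 3
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size             # 3
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size             # 4

assert (vit_merger_window_size, num_windows_h, num_windows_w) == (4, 3, 4)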
1351
|
+
|
|
1352
|
+
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
|
1353
|
+
"""
|
|
1354
|
+
Args:
|
|
1355
|
+
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
|
|
1356
|
+
The packed, flattened image/video patches to be embedded and encoded.
|
|
1357
|
+
grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
|
|
1358
|
+
The temporal, height and width of feature shape of each image in LLM.
|
|
1359
|
+
|
|
1360
|
+
Returns:
|
|
1361
|
+
`torch.Tensor`: The merged visual features, restored to the original patch order.
|
|
1362
|
+
"""
|
|
1363
|
+
hidden_states = self.patch_embed(hidden_states)
|
|
1364
|
+
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
|
1365
|
+
|
|
1366
|
+
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
|
|
1367
|
+
cu_window_seqlens = torch.tensor(
|
|
1368
|
+
cu_window_seqlens,
|
|
1369
|
+
device=hidden_states.device,
|
|
1370
|
+
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
|
1371
|
+
)
|
|
1372
|
+
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
|
1373
|
+
|
|
1374
|
+
seq_len, _ = hidden_states.size()
|
|
1375
|
+
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
|
1376
|
+
hidden_states = hidden_states[window_index, :, :]
|
|
1377
|
+
hidden_states = hidden_states.reshape(seq_len, -1)
|
|
1378
|
+
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
|
1379
|
+
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
|
|
1380
|
+
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
|
1381
|
+
|
|
1382
|
+
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
|
1383
|
+
dim=0,
|
|
1384
|
+
# Select dtype based on the following factors:
|
|
1385
|
+
# - FA2 requires that cu_seqlens_q must have dtype int32
|
|
1386
|
+
# - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
|
|
1387
|
+
# See https://github.com/huggingface/transformers/pull/34852 for more information
|
|
1388
|
+
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
|
1389
|
+
)
|
|
1390
|
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
|
1391
|
+
|
|
1392
|
+
# Modification here
|
|
1393
|
+
for layer_num, blk in enumerate(self.blocks):
|
|
1394
|
+
if layer_num in self.fullatt_block_indexes:
|
|
1395
|
+
cu_seqlens_now = cu_seqlens
|
|
1396
|
+
else:
|
|
1397
|
+
cu_seqlens_now = cu_window_seqlens
|
|
1398
|
+
if self.gradient_checkpointing and self.training:
|
|
1399
|
+
hidden_states = self._gradient_checkpointing_func(
|
|
1400
|
+
blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb
|
|
1401
|
+
)
|
|
1402
|
+
else:
|
|
1403
|
+
hidden_states = blk(
|
|
1404
|
+
hidden_states,
|
|
1405
|
+
cu_seqlens=cu_seqlens_now,
|
|
1406
|
+
rotary_pos_emb=rotary_pos_emb,
|
|
1407
|
+
)
|
|
1408
|
+
hidden_states = self.merger(hidden_states)
|
|
1409
|
+
reverse_indices = torch.argsort(window_index)
|
|
1410
|
+
hidden_states = hidden_states[reverse_indices, :]
|
|
1411
|
+
|
|
1412
|
+
return hidden_states
|
|
1413
|
+
|
|
1414
|
+
def get_dtype(self) -> torch.dtype:
|
|
1415
|
+
return self.blocks[0].mlp.gate_proj.weight.dtype
|
|
1416
|
+
|
|
1417
|
+
def get_device(self) -> torch.device:
|
|
1418
|
+
return self.blocks[0].mlp.gate_proj.weight.device
|
|
1419
|
+
|
|
1420
|
+
|
|
1421
|
+
class Qwen2_5OmniRotaryEmbedding(nn.Module):
|
|
1422
|
+
def __init__(self, config: Qwen2_5OmniThinkerConfig, device=None):
|
|
1423
|
+
super().__init__()
|
|
1424
|
+
# BC: "rope_type" was originally "type"
|
|
1425
|
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
|
1426
|
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
|
1427
|
+
else:
|
|
1428
|
+
self.rope_type = "default"
|
|
1429
|
+
self.max_seq_len_cached = config.max_position_embeddings
|
|
1430
|
+
self.original_max_seq_len = config.max_position_embeddings
|
|
1431
|
+
|
|
1432
|
+
self.config = config
|
|
1433
|
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
|
1434
|
+
|
|
1435
|
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
|
1436
|
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
1437
|
+
self.original_inv_freq = self.inv_freq
|
|
1438
|
+
|
|
1439
|
+
def _dynamic_frequency_update(self, position_ids, device):
|
|
1440
|
+
"""
|
|
1441
|
+
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
|
1442
|
+
1 - growing beyond the cached sequence length (allow scaling)
|
|
1443
|
+
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
|
1444
|
+
"""
|
|
1445
|
+
seq_len = torch.max(position_ids) + 1
|
|
1446
|
+
if seq_len > self.max_seq_len_cached: # growth
|
|
1447
|
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
|
1448
|
+
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
1449
|
+
)
|
|
1450
|
+
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
|
1451
|
+
self.max_seq_len_cached = seq_len
|
|
1452
|
+
|
|
1453
|
+
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
|
1454
|
+
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
|
1455
|
+
self.max_seq_len_cached = self.original_max_seq_len
|
|
1456
|
+
|
|
1457
|
+
@torch.no_grad()
|
|
1458
|
+
def forward(self, x, position_ids):
|
|
1459
|
+
if "dynamic" in self.rope_type:
|
|
1460
|
+
self._dynamic_frequency_update(position_ids, device=x.device)
|
|
1461
|
+
|
|
1462
|
+
# Core RoPE block. In contrast to other models, Qwen2_5Omni has different position ids for the grids,
|
|
1463
|
+
# So we expand the inv_freq to shape (3, ...)
|
|
1464
|
+
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
|
1465
|
+
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
|
|
1466
|
+
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
|
1467
|
+
device_type = x.device.type
|
|
1468
|
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
|
1469
|
+
with torch.autocast(device_type=device_type, enabled=False):
|
|
1470
|
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
|
|
1471
|
+
emb = torch.cat((freqs, freqs), dim=-1)
|
|
1472
|
+
cos = emb.cos()
|
|
1473
|
+
sin = emb.sin()
|
|
1474
|
+
|
|
1475
|
+
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
|
1476
|
+
cos = cos * self.attention_scaling
|
|
1477
|
+
sin = sin * self.attention_scaling
|
|
1478
|
+
|
|
1479
|
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
1480
|
+
|
|
1481
|
+
|
|
1482
|
+
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
|
1483
|
+
mrope_section = mrope_section * 2
|
|
1484
|
+
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
|
|
1485
|
+
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
|
|
1486
|
+
|
|
1487
|
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
|
1488
|
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
1489
|
+
return q_embed, k_embed
|
|
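apply_multimodal_rotary_pos_emb interleaves three rotary streams (temporal, height, width): cos/sin arrive with a leading size-3 axis, `mrope_section` (doubled because cos and sin each cover both halves of the head dim) gives the channel count of each chunk, and chunk i of the split is taken from stream i % 3. A small sketch of just that selection step, with assumed section sizes:

# Sketch of the mrope_section channel selection (assumed sizes; head_dim = 2 * sum(section)).
import torch

mrope_section = [2, 3, 3]                  # temporal / height / width channels (illustrative)
bs, seq, head_dim = 1, 4, 2 * sum(mrope_section)
cos = torch.randn(3, bs, seq, head_dim)    # one cos table per position stream

doubled = mrope_section * 2                # sections repeat for the second half of head_dim
chunks = cos.split(doubled, dim=-1)
mixed = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)

assert mixed.shape == (bs, seq, head_dim)  # stream axis folded away, channel count preserved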
1490
|
+
|
|
1491
|
+
|
|
1492
|
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
1493
|
+
"""
|
|
1494
|
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
|
|
1495
|
+
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
|
|
1496
|
+
(batch, num_attention_heads, seqlen, head_dim)
|
|
1497
|
+
"""
|
|
1498
|
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
|
1499
|
+
if n_rep == 1:
|
|
1500
|
+
return hidden_states
|
|
1501
|
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
|
1502
|
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
|
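As the docstring says, repeat_kv is an expand-plus-reshape equivalent of torch.repeat_interleave along the head axis; a quick standalone check on a small tensor:

# Standalone check that the expand+reshape trick matches torch.repeat_interleave on the head axis.
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

x = torch.randn(2, 4, 5, 8)                # (batch, num_key_value_heads, seqlen, head_dim)
assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))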
1503
|
+
|
|
1504
|
+
|
|
1505
|
+
class Qwen2_5OmniAttention(nn.Module):
|
|
1506
|
+
|
|
1507
|
+
def __init__(self, config: Qwen2_5OmniConfig, layer_idx: Optional[int] = None):
|
|
1508
|
+
super().__init__()
|
|
1509
|
+
self.config = config
|
|
1510
|
+
self.layer_idx = layer_idx
|
|
1511
|
+
if layer_idx is None:
|
|
1512
|
+
logger.warning_once(
|
|
1513
|
+
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
|
1514
|
+
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
|
1515
|
+
"when creating this class."
|
|
1516
|
+
)
|
|
1517
|
+
|
|
1518
|
+
self.hidden_size = config.hidden_size
|
|
1519
|
+
self.num_heads = config.num_attention_heads
|
|
1520
|
+
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
|
|
1521
|
+
self.num_key_value_heads = config.num_key_value_heads
|
|
1522
|
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
|
1523
|
+
self.is_causal = True
|
|
1524
|
+
self.attention_dropout = config.attention_dropout
|
|
1525
|
+
self.rope_scaling = config.rope_scaling
|
|
1526
|
+
|
|
1527
|
+
# if (self.head_dim * self.num_heads) != self.hidden_size:
|
|
1528
|
+
# raise ValueError(
|
|
1529
|
+
# f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
|
1530
|
+
# f" and `num_heads`: {self.num_heads})."
|
|
1531
|
+
# )
|
|
1532
|
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
|
1533
|
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
|
1534
|
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
|
1535
|
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
|
1536
|
+
|
|
1537
|
+
self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
|
|
1538
|
+
|
|
1539
|
+
def forward(
|
|
1540
|
+
self,
|
|
1541
|
+
hidden_states: torch.Tensor,
|
|
1542
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
1543
|
+
position_ids: Optional[torch.LongTensor] = None,
|
|
1544
|
+
past_key_value: Optional[Cache] = None,
|
|
1545
|
+
output_attentions: bool = False,
|
|
1546
|
+
use_cache: bool = False,
|
|
1547
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
1548
|
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
|
1549
|
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
1550
|
+
bsz, q_len, _ = hidden_states.size()
|
|
1551
|
+
|
|
1552
|
+
query_states = self.q_proj(hidden_states)
|
|
1553
|
+
key_states = self.k_proj(hidden_states)
|
|
1554
|
+
value_states = self.v_proj(hidden_states)
|
|
1555
|
+
|
|
1556
|
+
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
|
1557
|
+
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
|
1558
|
+
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
|
1559
|
+
|
|
1560
|
+
assert position_embeddings is not None
|
|
1561
|
+
cos, sin = position_embeddings
|
|
1562
|
+
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
|
1563
|
+
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
|
1564
|
+
)
|
|
1565
|
+
|
|
1566
|
+
if past_key_value is not None:
|
|
1567
|
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
|
1568
|
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
|
1569
|
+
|
|
1570
|
+
# repeat k/v heads if n_kv_heads < n_heads
|
|
1571
|
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
1572
|
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
1573
|
+
|
|
1574
|
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
|
1575
|
+
|
|
1576
|
+
if attention_mask is not None: # no matter the length, we just slice it
|
|
1577
|
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
|
1578
|
+
attn_weights = attn_weights + causal_mask
|
|
1579
|
+
|
|
1580
|
+
# Fix precision issues in Qwen2-VL float16 inference
|
|
1581
|
+
# Replace inf values with zeros in attention weights to prevent NaN propagation
|
|
1582
|
+
if query_states.dtype == torch.float16:
|
|
1583
|
+
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
|
|
1584
|
+
|
|
1585
|
+
# upcast attention to fp32
|
|
1586
|
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
|
1587
|
+
attn_weights_p: torch.Tensor = nn.functional.dropout(
|
|
1588
|
+
attn_weights, p=self.attention_dropout, training=self.training
|
|
1589
|
+
)
|
|
1590
|
+
return_attn_weights: Optional[torch.Tensor] = attn_weights_p
|
|
1591
|
+
attn_output = torch.matmul(attn_weights_p, value_states)
|
|
1592
|
+
|
|
1593
|
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
|
1594
|
+
raise ValueError(
|
|
1595
|
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
|
1596
|
+
f" {attn_output.size()}"
|
|
1597
|
+
)
|
|
1598
|
+
|
|
1599
|
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
1600
|
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
|
1601
|
+
|
|
1602
|
+
attn_output = self.o_proj(attn_output)
|
|
1603
|
+
|
|
1604
|
+
if not output_attentions:
|
|
1605
|
+
return_attn_weights = None
|
|
1606
|
+
|
|
1607
|
+
return attn_output, return_attn_weights, past_key_value
|
|
1608
|
+
|
|
1609
|
+
|
|
1610
|
+
class Qwen2MLP(nn.Module):
|
|
1611
|
+
def __init__(self, config, bias: bool = False):
|
|
1612
|
+
super().__init__()
|
|
1613
|
+
self.hidden_size = config.hidden_size
|
|
1614
|
+
self.intermediate_size = config.intermediate_size
|
|
1615
|
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
|
|
1616
|
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
|
|
1617
|
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
|
|
1618
|
+
self.act_fn = ACT2FN[config.hidden_act]
|
|
1619
|
+
|
|
1620
|
+
def forward(self, hidden_state):
|
|
1621
|
+
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
|
1622
|
+
|
|
1623
|
+
|
|
1624
|
+
class Qwen2_5OmniFlashAttention2(Qwen2_5OmniAttention):
|
|
1625
|
+
|
|
1626
|
+
def __init__(self, *args, **kwargs):
|
|
1627
|
+
super().__init__(*args, **kwargs)
|
|
1628
|
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
|
1629
|
+
|
|
1630
|
+
def forward(
|
|
1631
|
+
self,
|
|
1632
|
+
hidden_states: torch.Tensor,
|
|
1633
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
1634
|
+
position_ids: Optional[torch.LongTensor] = None,
|
|
1635
|
+
past_key_value: Optional[Cache] = None,
|
|
1636
|
+
output_attentions: bool = False,
|
|
1637
|
+
use_cache: bool = False,
|
|
1638
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
1639
|
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
|
1640
|
+
):
|
|
1641
|
+
bsz, q_len, _ = hidden_states.size()
|
|
1642
|
+
|
|
1643
|
+
query_states = self.q_proj(hidden_states)
|
|
1644
|
+
key_states = self.k_proj(hidden_states)
|
|
1645
|
+
value_states = self.v_proj(hidden_states)
|
|
1646
|
+
|
|
1647
|
+
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
|
1648
|
+
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
|
1649
|
+
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
|
1650
|
+
|
|
1651
|
+
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
|
1652
|
+
assert position_embeddings is not None
|
|
1653
|
+
cos, sin = position_embeddings
|
|
1654
|
+
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
|
1655
|
+
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
|
1656
|
+
)
|
|
1657
|
+
|
|
1658
|
+
if past_key_value is not None:
|
|
1659
|
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
|
1660
|
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
|
1661
|
+
|
|
1662
|
+
# repeat k/v heads if n_kv_heads < n_heads
|
|
1663
|
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
1664
|
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
1665
|
+
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
|
1666
|
+
|
|
1667
|
+
# In PEFT, the layer norms are usually cast to float32 for training stability reasons,
|
|
1668
|
+
# so the input hidden states get silently upcast to float32. Hence, we need to
|
|
1669
|
+
# cast them back to float16 just to be sure everything works as expected.
|
|
1670
|
+
input_dtype = query_states.dtype
|
|
1671
|
+
if input_dtype == torch.float32:
|
|
1672
|
+
if torch.is_autocast_enabled():
|
|
1673
|
+
target_dtype = torch.get_autocast_gpu_dtype()
|
|
1674
|
+
# Handle the case where the model is quantized
|
|
1675
|
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
|
1676
|
+
target_dtype = self.config._pre_quantization_dtype
|
|
1677
|
+
else:
|
|
1678
|
+
target_dtype = self.q_proj.weight.dtype
|
|
1679
|
+
|
|
1680
|
+
logger.warning_once(
|
|
1681
|
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
|
1682
|
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
|
1683
|
+
f" {target_dtype}."
|
|
1684
|
+
)
|
|
1685
|
+
|
|
1686
|
+
query_states = query_states.to(target_dtype)
|
|
1687
|
+
key_states = key_states.to(target_dtype)
|
|
1688
|
+
value_states = value_states.to(target_dtype)
|
|
1689
|
+
|
|
1690
|
+
# Reshape to the expected shape for Flash Attention
|
|
1691
|
+
query_states = query_states.transpose(1, 2)
|
|
1692
|
+
key_states = key_states.transpose(1, 2)
|
|
1693
|
+
value_states = value_states.transpose(1, 2)
|
|
1694
|
+
|
|
1695
|
+
if (
|
|
1696
|
+
self.config.use_sliding_window
|
|
1697
|
+
and getattr(self.config, "sliding_window", None) is not None
|
|
1698
|
+
and self.layer_idx >= self.config.max_window_layers
|
|
1699
|
+
):
|
|
1700
|
+
sliding_window = self.config.sliding_window
|
|
1701
|
+
else:
|
|
1702
|
+
sliding_window = None
|
|
1703
|
+
|
|
1704
|
+
attn_output = _flash_attention_forward(
|
|
1705
|
+
query_states,
|
|
1706
|
+
key_states,
|
|
1707
|
+
value_states,
|
|
1708
|
+
attention_mask,
|
|
1709
|
+
q_len,
|
|
1710
|
+
dropout=dropout_rate,
|
|
1711
|
+
sliding_window=sliding_window,
|
|
1712
|
+
is_causal=self.is_causal,
|
|
1713
|
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
|
1714
|
+
)
|
|
1715
|
+
|
|
1716
|
+
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
|
1717
|
+
attn_output = self.o_proj(attn_output)
|
|
1718
|
+
|
|
1719
|
+
if not output_attentions:
|
|
1720
|
+
attn_weights = None
|
|
1721
|
+
|
|
1722
|
+
return attn_output, attn_weights, past_key_value
|
|
1723
|
+
|
|
1724
|
+
|
|
1725
|
+
class Qwen2_5OmniSdpaAttention(Qwen2_5OmniAttention):
|
|
1726
|
+
"""
|
|
1727
|
+
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
|
1728
|
+
`Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
|
|
1729
|
+
SDPA API.
|
|
1730
|
+
"""
|
|
1731
|
+
|
|
1732
|
+
# Adapted from Qwen2Attention.forward
|
|
1733
|
+
def forward(
|
|
1734
|
+
self,
|
|
1735
|
+
hidden_states: torch.Tensor,
|
|
1736
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
1737
|
+
position_ids: Optional[torch.LongTensor] = None,
|
|
1738
|
+
past_key_value: Optional[Cache] = None,
|
|
1739
|
+
output_attentions: bool = False,
|
|
1740
|
+
use_cache: bool = False,
|
|
1741
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
1742
|
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
1743
|
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
1744
|
+
if output_attentions:
|
|
1745
|
+
logger.warning_once(
|
|
1746
|
+
"Qwen2_5OmniModel is using Qwen2_5OmniSdpaAttention, but "
|
|
1747
|
+
"`torch.nn.functional.scaled_dot_product_attention`"
|
|
1748
|
+
" does not support `output_attentions=True`. Falling back to "
|
|
1749
|
+
"the manual attention implementation, "
|
|
1750
|
+
"but specifying the manual implementation will be required from "
|
|
1751
|
+
"Transformers version v5.0.0 onwards."
|
|
1752
|
+
" This warning can be removed using the argument "
|
|
1753
|
+
'"`attn_implementation="eager"` when loading the model.'
|
|
1754
|
+
)
|
|
1755
|
+
return super().forward(
|
|
1756
|
+
hidden_states=hidden_states,
|
|
1757
|
+
attention_mask=attention_mask,
|
|
1758
|
+
position_ids=position_ids,
|
|
1759
|
+
past_key_value=past_key_value,
|
|
1760
|
+
output_attentions=output_attentions,
|
|
1761
|
+
use_cache=use_cache,
|
|
1762
|
+
cache_position=cache_position,
|
|
1763
|
+
position_embeddings=position_embeddings,
|
|
1764
|
+
)
|
|
1765
|
+
|
|
1766
|
+
bsz, q_len, _ = hidden_states.size()
|
|
1767
|
+
|
|
1768
|
+
query_states = self.q_proj(hidden_states)
|
|
1769
|
+
key_states = self.k_proj(hidden_states)
|
|
1770
|
+
value_states = self.v_proj(hidden_states)
|
|
1771
|
+
|
|
1772
|
+
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
|
1773
|
+
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
|
1774
|
+
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
|
1775
|
+
|
|
1776
|
+
assert position_embeddings is not None
|
|
1777
|
+
cos, sin = position_embeddings
|
|
1778
|
+
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
|
1779
|
+
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
|
1780
|
+
)
|
|
1781
|
+
|
|
1782
|
+
if past_key_value is not None:
|
|
1783
|
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
|
1784
|
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
|
1785
|
+
|
|
1786
|
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
1787
|
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
1788
|
+
|
|
1789
|
+
causal_mask = attention_mask
|
|
1790
|
+
if attention_mask is not None: # no matter the length, we just slice it
|
|
1791
|
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
|
1792
|
+
|
|
1793
|
+
if query_states.device.type == "cuda" and attention_mask is not None:
|
|
1794
|
+
query_states = query_states.contiguous()
|
|
1795
|
+
key_states = key_states.contiguous()
|
|
1796
|
+
value_states = value_states.contiguous()
|
|
1797
|
+
|
|
1798
|
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
|
1799
|
+
|
|
1800
|
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
|
1801
|
+
query_states,
|
|
1802
|
+
key_states,
|
|
1803
|
+
value_states,
|
|
1804
|
+
attn_mask=causal_mask,
|
|
1805
|
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
|
1806
|
+
is_causal=is_causal,
|
|
1807
|
+
)
|
|
1808
|
+
|
|
1809
|
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
1810
|
+
attn_output = attn_output.view(bsz, q_len, -1)
|
|
1811
|
+
|
|
1812
|
+
attn_output = self.o_proj(attn_output)
|
|
1813
|
+
|
|
1814
|
+
return attn_output, None, past_key_value
|
|
1815
|
+
|
|
1816
|
+
|
|
1817
|
+
QWEN2_5_OMNI_ATTENTION_CLASSES = {
    "eager": Qwen2_5OmniAttention,
    "flash_attention_2": Qwen2_5OmniFlashAttention2,
    "sdpa": Qwen2_5OmniSdpaAttention,
}


class Qwen2_5OmniDecoderLayer(nn.Module):
    def __init__(self, config: Qwen2_5OmniConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for "
                f"`{config._attn_implementation}`; "
                f"unexpected results may be encountered."
            )
        self.self_attn = QWEN2_5_OMNI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ):

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs: Tuple[Any, ...]
        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


QWEN2_5OMNI_START_DOCSTRING = r"""add doc"""

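# Text-only decoder stack of the Thinker: token embeddings, rotary embeddings shared
# across the layers, a stack of Qwen2_5OmniDecoderLayer blocks, and a final RMSNorm.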
@add_start_docstrings(
    "The bare Qwen2.5OmniThinker Model outputting raw hidden-states without any specific head on top.",
    QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTextConfig"),
)
class Qwen2_5OmniThinkerModel(Qwen2_5OmniPreTrainedModel):
    config_class = Qwen2_5OmniTextConfig
    _no_split_modules = ["Qwen2_5OmniDecoderLayer"]

    def __init__(self, config: Qwen2_5OmniTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2_5OmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states: Optional[Tuple[Any, ...]] = () if output_hidden_states else None
        all_self_attns: Optional[Tuple[Any, ...]] = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states and hidden_states is not None and all_hidden_states is not None:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions and layer_outputs is not None and all_self_attns is not None:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states and all_hidden_states is not None:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

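    # Builds (or skips) the 4D causal mask depending on the attention backend:
    # flash-attention-2 uses the raw 2D mask (or none), SDPA can often rely on its
    # `is_causal` fast path, and the remaining cases fall back to an explicit 4D mask.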
    def _update_causal_mask(
        self,
        attention_mask,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version "
                        "of Qwen25OmniThinker. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

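    # Expands a 2D padding mask into the 4D additive mask expected by the attention
    # layers, applying the causal constraint, the optional sliding window, and the
    # per-sample padding positions.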
    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config: Qwen2_5OmniConfig,
        past_key_values: Cache,
    ):
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            if config.sliding_window is not None:

                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask

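# Thinker head: multimodal front-end (audio tower + vision encoder) feeding the text decoder.
# Audio/image/video features are scattered into the token embeddings at the positions of their
# placeholder tokens; a copy with those positions zeroed out (`embeds_to_talker`) is returned
# for the talker to condition on.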
@add_start_docstrings(
    """The Qwen2.5OmniThinker model which consists of an audio backbone and a language model.""",
    QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniThinkerConfig"),
)
class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
    config_class = Qwen2_5OmniThinkerConfig
    _no_split_modules = ["Qwen2_5OmniAudioEncoder", "Qwen2_5OmniVisionEncoder"]

    def __init__(self, config: Qwen2_5OmniThinkerConfig):
        super().__init__(config)
        self.audio_tower = Qwen2_5OmniAudioEncoder._from_config(
            config.audio_config, attn_implementation=config._attn_implementation
        )

        self.visual = Qwen2_5OmniVisionEncoder._from_config(
            config.vision_config, attn_implementation=config._attn_implementation
        )

        self.vocab_size = config.text_config.vocab_size
        self.model = Qwen2_5OmniThinkerModel._from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.spatial_merge_size = config.vision_config.spatial_merge_size
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        input_features: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        feature_attention_mask: Optional[torch.Tensor] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        rope_deltas: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_audio_in_video: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        video_second_per_grid: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if feature_attention_mask is not None and input_features is not None:
            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
            input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
        else:
            audio_feature_lengths = None
        if attention_mask is not None and position_ids is None:
            if cache_position is None or (cache_position is not None and cache_position[0] == 0):
                delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                position_ids_p, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    attention_mask,
                    use_audio_in_video,
                    audio_feature_lengths,
                    video_second_per_grid,
                )
                rope_deltas = rope_deltas - delta0

            else:
                assert input_ids is not None
                batch_size, seq_length = input_ids.shape
                delta = (
                    cache_position[0] + rope_deltas
                    if cache_position is not None and rope_deltas is not None
                    else torch.tensor(0, device=input_ids.device)
                )
                position_ids = torch.arange(seq_length, device=input_ids.device)
                position_ids_p = position_ids.view(1, -1).expand(batch_size, -1)
                position_ids_p = position_ids_p.add(delta)
                position_ids_p = position_ids_p.unsqueeze(0).expand(3, -1, -1)

        if inputs_embeds is None and input_ids is not None:
            # 1. Extract the input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)
            embeds_to_talker = inputs_embeds.clone()

            # 2. Merge text, audio, image and video
            if input_ids.shape[1] != 1:
                if input_features is not None and feature_attention_mask is not None:
                    audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
                        audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
                    )
                    feature_lens = (
                        audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
                    )
                    audio_outputs = self.audio_tower(
                        input_features,
                        feature_lens=feature_lens,
                        aftercnn_lens=audio_feat_lengths,
                    )
                    audio_features = audio_outputs.last_hidden_state
                    if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
                        raise ValueError("length of audio_features should match audio_output_lengths")
                    audio_mask = (input_ids == self.config.audio_token_index).unsqueeze(-1).expand_as(inputs_embeds)
                    audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
                    inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
                    embeds_to_talker = embeds_to_talker.masked_scatter(audio_mask, torch.zeros_like(audio_features))

                if pixel_values is not None:
                    pixel_values = pixel_values.type(self.visual.get_dtype())
                    image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                    image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
                    image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                    inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
                    embeds_to_talker = embeds_to_talker.masked_scatter(image_mask, torch.zeros_like(image_embeds))

                if pixel_values_videos is not None:
                    pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                    video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                    video_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
                    video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                    inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
                    embeds_to_talker = embeds_to_talker.masked_scatter(video_mask, torch.zeros_like(video_embeds))

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids_p,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            logits = logits.float()
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))

        if not return_dict:
            output = (logits,) + ((embeds_to_talker, outputs[0])) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2_5OmniThinkerCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=(embeds_to_talker, outputs.hidden_states),
            attentions=outputs.attentions,
            attention_mask=attention_mask,
            rope_deltas=rope_deltas,
        )

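    # Generation helpers: after the prefill step (cache_position[0] != 0) the visual
    # inputs are already merged into the cache, so pixel values are dropped and
    # position ids are recomputed inside forward() rather than by GenerationMixin.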
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        input_features=None,
        feature_attention_mask=None,
        use_audio_in_video=False,
        video_second_per_grid=None,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            use_audio_in_video=use_audio_in_video,
            video_second_per_grid=video_second_per_grid,
            **kwargs,
        )

        model_inputs["position_ids"] = None

        if cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None

        return model_inputs

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        # update attention_mask
        if getattr(outputs, "attention_mask", None) is not None:
            model_kwargs["attention_mask"] = outputs.attention_mask

        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )

        if getattr(outputs, "rope_deltas", None) is not None:
            model_kwargs["rope_deltas"] = outputs.rope_deltas

        return model_kwargs

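# Output container for the talker: mirrors the causal-LM output and additionally carries
# the attention mask, rope deltas, and the remaining thinker hidden states to condition on.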
@dataclass
class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput):

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    attention_mask: Optional[torch.Tensor] = None
    rope_deltas: Optional[torch.LongTensor] = None
    thinker_reply_part: Optional[torch.Tensor] = None


@add_start_docstrings(
    "The bare Qwen2.5OmniTalker Model outputting raw hidden-states without any specific head on top.",
    QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTalkerConfig"),
)
class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
    config_class = Qwen2_5OmniTalkerConfig
    _no_split_modules = ["Qwen2_5OmniTalkerDecoderLayer"]

    def __init__(self, config: Qwen2_5OmniTalkerConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2_5OmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Any] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens: Any
            if past_key_values is not None:
                past_seen_tokens = past_key_values.get_seq_length()
            else:
                past_seen_tokens = 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None and cache_position is not None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        assert attention_mask is not None and cache_position is not None
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if output_hidden_states else None
        all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states and all_hidden_states is not None and hidden_states is not None:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions and all_self_attns is not None and layer_outputs is not None:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states and all_hidden_states is not None and hidden_states is not None:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version "
                        "of Qwen25OmniTalker. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config: Qwen2_5OmniConfig,
        past_key_values: Cache,
    ):
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            if config.sliding_window is not None:
                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask

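# Talker head: projects thinker embeddings into the talker hidden size via
# `thinker_to_talker_proj` and predicts discrete TTS codec tokens with `codec_head`;
# the codec/text BOS, EOS, pad and mask ids come from the talker config.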
class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
    config_class = Qwen2_5OmniTalkerConfig

    def __init__(self, config: Qwen2_5OmniTalkerConfig):
        super().__init__(config)

        self.thinker_to_talker_proj = nn.Linear(config.embedding_size, config.hidden_size)

        self.model = Qwen2_5OmniTalkerModel(config)
        self.codebook_size = config.vocab_size
        self.codec_head = nn.Linear(config.hidden_size, self.codebook_size, bias=False)

        self.codec_bos_token = config.tts_codec_start_token_id
        self.codec_eos_token = config.tts_codec_end_token_id
        self.codec_pad_token = config.tts_codec_pad_token_id
        self.codec_mask_token = config.tts_codec_mask_token_id

        self.text_bos_token = config.tts_text_start_token_id
        self.text_eos_token = config.tts_text_end_token_id
        self.text_pad_token = config.tts_text_pad_token_id

        self.spatial_merge_size = self.config.spatial_merge_size

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        thinker_reply_part: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        input_text_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        use_audio_in_video: Optional[bool] = None,
        audio_feature_lengths: Optional[torch.LongTensor] = None,
        video_second_per_grid: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Qwen2_5OmniTalkerCausalLMOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is not None and position_ids is None:
            if cache_position is None or (cache_position is not None and cache_position[0] == 0):
                position_ids, rope_deltas = self.get_rope_index(
                    input_text_ids,
                    image_grid_thw,
                    video_grid_thw,
                    attention_mask,
                    use_audio_in_video,
                    audio_feature_lengths,
                    video_second_per_grid,
                )
                assert inputs_embeds is not None
                inputs_embeds[:, -1, :] += self.get_input_embeddings()(
                    torch.tensor([self.codec_bos_token], dtype=torch.long, device=inputs_embeds.device)
                )
                inputs_embeds[:, -2, :] += self.get_input_embeddings()(
                    torch.tensor([self.codec_pad_token], dtype=torch.long, device=inputs_embeds.device)
                )

            else:
                assert input_ids is not None
                batch_size, seq_length = input_ids.shape
                delta = (
                    cache_position[0] + rope_deltas
                    if cache_position is not None and rope_deltas is not None
                    else torch.tensor(0, device=input_ids.device)
                )
                position_ids = torch.arange(seq_length, device=input_ids.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        if inputs_embeds is None:
            assert thinker_reply_part is not None
            # 1. Decode the second and subsequent tokens
            codec_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds = codec_embeds + thinker_reply_part[:, :1, :]
            if thinker_reply_part.shape[1] > 1:
                thinker_reply_part = thinker_reply_part[:, 1:, :]

        talker_lm_input = self.thinker_to_talker_proj(inputs_embeds)

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=talker_lm_input,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.codec_head(hidden_states)
        logits = logits.float()

        loss = None

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2_5OmniTalkerCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=hidden_states,
            attentions=outputs.attentions,
            attention_mask=attention_mask,
            rope_deltas=rope_deltas,
            thinker_reply_part=thinker_reply_part,
        )

    def _get_initial_cache_position(self, input_ids, model_kwargs):
        # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily
        inputs_embeds = model_kwargs.pop("inputs_embeds")
        model_kwargs = super()._get_initial_cache_position(input_ids, model_kwargs)
        model_kwargs["inputs_embeds"] = inputs_embeds
        return model_kwargs

    # prepare inputs for talker lm generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
        input_text_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        thinker_reply_part=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        input_audio_features=None,
        audio_feature_attention_mask=None,
        audio_feature_lengths=None,
        use_audio_in_video=False,
        video_second_per_grid=None,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values,
            attention_mask,
            inputs_embeds,
            cache_position,
            use_cache=use_cache,
            thinker_reply_part=thinker_reply_part,
            input_text_ids=input_text_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            use_audio_in_video=use_audio_in_video,
            audio_feature_lengths=audio_feature_lengths,
            video_second_per_grid=video_second_per_grid,
            **kwargs,
        )

        model_inputs["position_ids"] = None

        return model_inputs

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        # update attention_mask
        if getattr(outputs, "attention_mask", None) is not None:
            model_kwargs["attention_mask"] = outputs.attention_mask

        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )

        if getattr(outputs, "rope_deltas", None) is not None:
            model_kwargs["rope_deltas"] = outputs.rope_deltas

        if getattr(outputs, "thinker_reply_part", None) is not None:
            model_kwargs["thinker_reply_part"] = outputs.thinker_reply_part

        return model_kwargs

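# The modules below (rotary embedding helper, TDNN/Res2Net/SE blocks, attentive statistics
# pooling) are building blocks for the ECAPA_TDNN encoder defined further down, which is
# configured via Qwen2_5OmniDiTConfig.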
# Using custom RoPE, will use LlamaRotaryEmbedding next version
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x):
        batch_size, seq_len = x.shape[0], x.shape[1]
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i , j -> i j", t.type_as(self.inv_freq), self.inv_freq)
        freqs = torch.stack((freqs, freqs), dim=-1)
        freqs = freqs.reshape(*freqs.shape[:-2], -1)
        freqs = freqs.repeat(batch_size, *([1] * freqs.dim()))

        return freqs.cos(), freqs.sin()

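# Basic TDNN unit: a 1D convolution with reflect padding followed by ReLU, reused by the
# Res2Net, SE, and pooling blocks below.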
class TDNNBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding="same",
            padding_mode="reflect",
        )
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.activation(self.conv(x))


class Res2NetBlock(torch.nn.Module):
    """An implementation of Res2NetBlock w/ dilation.

    Arguments
    ---------
    in_channels : int
        The number of channels expected in the input.
    out_channels : int
        The number of output channels.
    scale : int
        The scale of the Res2Net block.
    kernel_size: int
        The kernel size of the Res2Net block.
    dilation : int
        The dilation of the Res2Net block.
    """

    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
        super().__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        in_channel = in_channels // scale
        hidden_channel = out_channels // scale

        self.blocks = nn.ModuleList(
            [
                TDNNBlock(
                    in_channel,
                    hidden_channel,
                    kernel_size=kernel_size,
                    dilation=dilation,
                )
                for i in range(scale - 1)
            ]
        )
        self.scale = scale

    def forward(self, x):
        y = []
        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
            if i == 0:
                y_i = x_i
            elif i == 1:
                y_i = self.blocks[i - 1](x_i)
            else:
                y_i = self.blocks[i - 1](x_i + y_i)
            y.append(y_i)
        y_p = torch.cat(y, dim=1)
        return y_p


class SEBlock(nn.Module):
    """An implementation of squeeze-and-excitation block.

    Arguments
    ---------
    in_channels : int
        The number of input channels.
    se_channels : int
        The number of output channels after squeeze.
    out_channels : int
        The number of output channels.
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super().__init__()

        self.conv1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=se_channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(
            in_channels=se_channels,
            out_channels=out_channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        s = x.mean(dim=2, keepdim=True)

        s = self.relu(self.conv1(s))
        s = self.sigmoid(self.conv2(s))

        return s * x

class AttentiveStatisticsPooling(nn.Module):
    """This class implements an attentive statistic pooling layer for each channel.
    It returns the concatenated mean and std of the input tensor.

    Arguments
    ---------
    channels: int
        The number of input channels.
    attention_channels: int
        The number of attention channels.
    """

    def __init__(self, channels, attention_channels=128):
        super().__init__()

        self.eps = 1e-12
        self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = nn.Conv1d(
            in_channels=attention_channels,
            out_channels=channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )

    def _length_to_mask(self, length, max_len=None, dtype=None, device=None):
        assert len(length.shape) == 1

        if max_len is None:
            max_len = length.max().long().item()  # using arange to generate mask
        mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
            len(length), max_len
        ) < length.unsqueeze(1)

        mask = torch.as_tensor(mask, dtype=dtype, device=device)
        return mask

    def _compute_statistics(self, x, m, dim=2):
        mean = (m * x).sum(dim)
        std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps))
        return mean, std

    def forward(self, x):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape [N, C, L].
        """
        L = x.shape[-1]

        lengths = torch.ones(x.shape[0], device=x.device)

        # Make binary mask of shape [N, 1, L]
        mask = self._length_to_mask(lengths * L, max_len=L, dtype=x.dtype, device=x.device)
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        total = mask.sum(dim=2, keepdim=True)

        mean, std = self._compute_statistics(x, mask / total)
        mean = mean.unsqueeze(2).repeat(1, 1, L)
        std = std.unsqueeze(2).repeat(1, 1, L)
        attn = torch.cat([x, mean, std], dim=1)

        # Apply layers
        attn = self.conv(self.tanh(self.tdnn(attn)))

        # Filter out zero-paddings
        attn = attn.masked_fill(mask == 0, float("-inf"))

        attn = F.softmax(attn, dim=2)
        mean, std = self._compute_statistics(x, attn)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats

class SERes2NetBlock(nn.Module):
|
|
3125
|
+
"""An implementation of building block in ECAPA-TDNN, i.e.,
|
|
3126
|
+
TDNN-Res2Net-TDNN-SEBlock.
|
|
3127
|
+
|
|
3128
|
+
Arguments
|
|
3129
|
+
---------
|
|
3130
|
+
in_channels: int
The number of input channels.
out_channels: int
|
|
3131
|
+
The number of output channels.
|
|
3132
|
+
res2net_scale: int
|
|
3133
|
+
The scale of the Res2Net block.
|
|
3134
|
+
kernel_size: int
|
|
3135
|
+
The kernel size of the TDNN blocks.
|
|
3136
|
+
dilation: int
|
|
3137
|
+
The dilation of the Res2Net block.
|
|
3138
|
+
activation : torch class
|
|
3139
|
+
A class for constructing the activation layers.
|
|
3140
|
+
"""
|
|
3141
|
+
|
|
3142
|
+
def __init__(
|
|
3143
|
+
self,
|
|
3144
|
+
in_channels,
|
|
3145
|
+
out_channels,
|
|
3146
|
+
res2net_scale=8,
|
|
3147
|
+
se_channels=128,
|
|
3148
|
+
kernel_size=1,
|
|
3149
|
+
dilation=1,
|
|
3150
|
+
):
|
|
3151
|
+
super().__init__()
|
|
3152
|
+
self.out_channels = out_channels
|
|
3153
|
+
self.tdnn1 = TDNNBlock(
|
|
3154
|
+
in_channels,
|
|
3155
|
+
out_channels,
|
|
3156
|
+
kernel_size=1,
|
|
3157
|
+
dilation=1,
|
|
3158
|
+
)
|
|
3159
|
+
self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
|
|
3160
|
+
self.tdnn2 = TDNNBlock(
|
|
3161
|
+
out_channels,
|
|
3162
|
+
out_channels,
|
|
3163
|
+
kernel_size=1,
|
|
3164
|
+
dilation=1,
|
|
3165
|
+
)
|
|
3166
|
+
self.se_block = SEBlock(out_channels, se_channels, out_channels)
|
|
3167
|
+
|
|
3168
|
+
def forward(self, x):
|
|
3169
|
+
residual = x
|
|
3170
|
+
|
|
3171
|
+
x = self.tdnn1(x)
|
|
3172
|
+
x = self.res2net_block(x)
|
|
3173
|
+
x = self.tdnn2(x)
|
|
3174
|
+
x = self.se_block(x)
|
|
3175
|
+
|
|
3176
|
+
return x + residual
|
|
3177
|
+
|
|
3178
|
+
|
|
3179
|
+
class ECAPA_TDNN(torch.nn.Module):
|
|
3180
|
+
|
|
3181
|
+
def __init__(self, config: Qwen2_5OmniDiTConfig):
|
|
3182
|
+
super().__init__()
|
|
3183
|
+
assert len(config.enc_channels) == len(config.enc_kernel_sizes)
|
|
3184
|
+
assert len(config.enc_channels) == len(config.enc_dilations)
|
|
3185
|
+
self.channels = config.enc_channels
|
|
3186
|
+
self.blocks = nn.ModuleList()
|
|
3187
|
+
|
|
3188
|
+
# The initial TDNN layer
|
|
3189
|
+
self.blocks.append(
|
|
3190
|
+
TDNNBlock(
|
|
3191
|
+
config.mel_dim,
|
|
3192
|
+
config.enc_channels[0],
|
|
3193
|
+
config.enc_kernel_sizes[0],
|
|
3194
|
+
config.enc_dilations[0],
|
|
3195
|
+
)
|
|
3196
|
+
)
|
|
3197
|
+
|
|
3198
|
+
# SE-Res2Net layers
|
|
3199
|
+
for i in range(1, len(config.enc_channels) - 1):
|
|
3200
|
+
self.blocks.append(
|
|
3201
|
+
SERes2NetBlock(
|
|
3202
|
+
config.enc_channels[i - 1],
|
|
3203
|
+
config.enc_channels[i],
|
|
3204
|
+
res2net_scale=config.enc_res2net_scale,
|
|
3205
|
+
se_channels=config.enc_se_channels,
|
|
3206
|
+
kernel_size=config.enc_kernel_sizes[i],
|
|
3207
|
+
dilation=config.enc_dilations[i],
|
|
3208
|
+
)
|
|
3209
|
+
)
|
|
3210
|
+
|
|
3211
|
+
# Multi-layer feature aggregation
|
|
3212
|
+
self.mfa = TDNNBlock(
|
|
3213
|
+
config.enc_channels[-1],
|
|
3214
|
+
config.enc_channels[-1],
|
|
3215
|
+
config.enc_kernel_sizes[-1],
|
|
3216
|
+
config.enc_dilations[-1],
|
|
3217
|
+
)
|
|
3218
|
+
|
|
3219
|
+
# Attentive Statistical Pooling
|
|
3220
|
+
self.asp = AttentiveStatisticsPooling(
|
|
3221
|
+
config.enc_channels[-1],
|
|
3222
|
+
attention_channels=config.enc_attention_channels,
|
|
3223
|
+
)
|
|
3224
|
+
|
|
3225
|
+
# Final linear transformation
|
|
3226
|
+
self.fc = nn.Conv1d(
|
|
3227
|
+
in_channels=config.enc_channels[-1] * 2,
|
|
3228
|
+
out_channels=config.enc_dim,
|
|
3229
|
+
kernel_size=1,
|
|
3230
|
+
padding="same",
|
|
3231
|
+
padding_mode="reflect",
|
|
3232
|
+
)
|
|
3233
|
+
|
|
3234
|
+
def forward(self, x):
|
|
3235
|
+
"""Returns the embedding vector.
|
|
3236
|
+
|
|
3237
|
+
Arguments
|
|
3238
|
+
---------
|
|
3239
|
+
x : torch.Tensor
|
|
3240
|
+
Tensor of shape (batch, time, channel).
|
|
3241
|
+
"""
|
|
3242
|
+
# Minimize transpose for efficiency
|
|
3243
|
+
x = x.transpose(1, 2)
|
|
3244
|
+
|
|
3245
|
+
xl = []
|
|
3246
|
+
for layer in self.blocks:
|
|
3247
|
+
x = layer(x)
|
|
3248
|
+
xl.append(x)
|
|
3249
|
+
|
|
3250
|
+
# Multi-layer feature aggregation
|
|
3251
|
+
x = torch.cat(xl[1:], dim=1)
|
|
3252
|
+
x = self.mfa(x)
|
|
3253
|
+
|
|
3254
|
+
# Attentive Statistical Pooling
|
|
3255
|
+
x = self.asp(x)
|
|
3256
|
+
|
|
3257
|
+
# Final linear transformation
|
|
3258
|
+
x = self.fc(x)
|
|
3259
|
+
|
|
3260
|
+
x = x.squeeze(-1)
|
|
3261
|
+
return x
|
|
3262
|
+
|
|
3263
|
+
|
|
3264
|
+
class InputEmbedding(nn.Module):
|
|
3265
|
+
def __init__(self, config: Qwen2_5OmniDiTConfig):
|
|
3266
|
+
super().__init__()
|
|
3267
|
+
self.proj = nn.Linear(
|
|
3268
|
+
config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim,
|
|
3269
|
+
config.hidden_size,
|
|
3270
|
+
)
|
|
3271
|
+
self.spk_encoder = ECAPA_TDNN(config)
|
|
3272
|
+
|
|
3273
|
+
def forward(self, x, spk, cond, code_embed, drop_audio_cond=False, code_embed_uncond=None, cfg=True):
|
|
3274
|
+
if cfg:
|
|
3275
|
+
x = torch.cat([x, x], dim=0)
|
|
3276
|
+
spk = torch.cat([spk, torch.zeros_like(spk)], dim=0)
|
|
3277
|
+
cond = torch.cat([cond, torch.zeros_like(cond)], dim=0)
|
|
3278
|
+
code_embed = torch.cat([code_embed, code_embed_uncond], dim=0)
|
|
3279
|
+
elif drop_audio_cond: # cfg for cond audio
|
|
3280
|
+
cond = torch.zeros_like(cond)
|
|
3281
|
+
spk = torch.zeros_like(spk)
|
|
3282
|
+
cond = self.spk_encoder(cond).unsqueeze(1).repeat(1, x.size(1), 1)
|
|
3283
|
+
x = self.proj(torch.cat((x, cond, code_embed, spk), dim=-1))
|
|
3284
|
+
|
|
3285
|
+
return x
|
|
3286
|
+
|
|
3287
|
+
|
|
3288
|
+
# Codec (speech token) embedding used by the DiT transformer backbone
|
|
3289
|
+
class CodecEmbedding(nn.Module):
|
|
3290
|
+
def __init__(self, codec_num_embeds, codec_dim, repeats):
|
|
3291
|
+
super().__init__()
|
|
3292
|
+
self.repeats = repeats
|
|
3293
|
+
self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim)
|
|
3294
|
+
|
|
3295
|
+
def forward(self, code, drop_code=False):
|
|
3296
|
+
if drop_code:
|
|
3297
|
+
code = torch.zeros_like(code)
|
|
3298
|
+
code_embed = self.codec_embed(code)
|
|
3299
|
+
|
|
3300
|
+
code_embed = torch.repeat_interleave(code_embed, repeats=self.repeats, dim=1)
|
|
3301
|
+
return code_embed
|
|
3302
|
+
|
|
3303
|
+
|
|
3304
|
+
# AdaLayerNormZero
|
|
3305
|
+
# returns the modulated x for the attention input, plus parameters for the later MLP modulation
|
|
3306
|
+
class AdaLayerNormZero(nn.Module):
|
|
3307
|
+
def __init__(self, dim):
|
|
3308
|
+
super().__init__()
|
|
3309
|
+
|
|
3310
|
+
self.silu = nn.SiLU()
|
|
3311
|
+
self.linear = nn.Linear(dim, dim * 6)
|
|
3312
|
+
|
|
3313
|
+
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
|
3314
|
+
|
|
3315
|
+
def forward(self, x, emb=None):
|
|
3316
|
+
emb = self.linear(self.silu(emb))
|
|
3317
|
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
|
|
3318
|
+
|
|
3319
|
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
|
3320
|
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
|
3321
|
+
|
|
3322
|
+
|
|
3323
|
+
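# Illustrative note (not part of the diff): given x of shape [B, T, D] and a
# conditioning embedding emb of shape [B, D], AdaLayerNormZero projects emb to
# 6 * D values and splits them into (shift, scale, gate) triples for the attention
# and MLP branches; the returned x is LayerNorm(x) * (1 + scale_msa) + shift_msa,
# broadcast over the time dimension.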
# AdaLayerNormZero for final layer
|
|
3324
|
+
# returns only the modulated x for the attention input, since there is no further MLP modulation
|
|
3325
|
+
class AdaLayerNormZero_Final(nn.Module):
|
|
3326
|
+
def __init__(self, dim):
|
|
3327
|
+
super().__init__()
|
|
3328
|
+
|
|
3329
|
+
self.silu = nn.SiLU()
|
|
3330
|
+
self.linear = nn.Linear(dim, dim * 2)
|
|
3331
|
+
|
|
3332
|
+
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
|
3333
|
+
|
|
3334
|
+
def forward(self, x, emb):
|
|
3335
|
+
emb = self.linear(self.silu(emb))
|
|
3336
|
+
scale, shift = torch.chunk(emb, 2, dim=1)
|
|
3337
|
+
|
|
3338
|
+
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
|
3339
|
+
return x
|
|
3340
|
+
|
|
3341
|
+
|
|
3342
|
+
# FeedForward
|
|
3343
|
+
class FeedForward(nn.Module):
|
|
3344
|
+
def __init__(self, dim, mult=4, dropout=0.0):
|
|
3345
|
+
super().__init__()
|
|
3346
|
+
inner_dim = int(dim * mult)
|
|
3347
|
+
|
|
3348
|
+
self.ff = nn.ModuleList(
|
|
3349
|
+
[
|
|
3350
|
+
nn.Linear(dim, inner_dim),
|
|
3351
|
+
nn.GELU(approximate="tanh"),
|
|
3352
|
+
nn.Dropout(dropout),
|
|
3353
|
+
nn.Linear(inner_dim, dim),
|
|
3354
|
+
]
|
|
3355
|
+
)
|
|
3356
|
+
|
|
3357
|
+
def forward(self, x):
|
|
3358
|
+
for layer in self.ff:
|
|
3359
|
+
x = layer(x)
|
|
3360
|
+
return x
|
|
3361
|
+
|
|
3362
|
+
|
|
3363
|
+
# Modified from Llama with a different rotate function; will be fixed in a future release
|
|
3364
|
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
3365
|
+
|
|
3366
|
+
def rotate_half_codec(x):
|
|
3367
|
+
# x = rearrange(x, "... (d r) -> ... d r", r=2)
|
|
3368
|
+
x = x.reshape(*x.shape[:-1], -1, 2)
|
|
3369
|
+
x1, x2 = x.unbind(dim=-1)
|
|
3370
|
+
x = torch.stack((-x2, x1), dim=-1)
|
|
3371
|
+
return x.reshape(*x.shape[:-2], -1)
|
|
3372
|
+
|
|
3373
|
+
cos = cos.unsqueeze(unsqueeze_dim)
|
|
3374
|
+
sin = sin.unsqueeze(unsqueeze_dim)
|
|
3375
|
+
q_embed = (q * cos) + (rotate_half_codec(q) * sin)
|
|
3376
|
+
k_embed = (k * cos) + (rotate_half_codec(k) * sin)
|
|
3377
|
+
return q_embed, k_embed
|
|
3378
|
+
|
|
3379
|
+
|
|
3380
|
+
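# Illustrative note (not part of the diff): unlike Llama's rotate_half, which
# splits the head dimension into two contiguous halves, rotate_half_codec above
# rotates interleaved (even, odd) pairs: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2].
# A minimal sanity check, assuming torch is imported as in this module:
#   q = torch.arange(4.0).view(1, 1, 1, 4)
#   cos = torch.ones(1, 1, 4); sin = torch.zeros(1, 1, 4)
#   apply_rotary_pos_emb(q, q, cos, sin)[0]  # equals q when sin is all zeros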
class DiTAttention(nn.Module):
|
|
3381
|
+
def __init__(self, config: Qwen2_5OmniDiTConfig):
|
|
3382
|
+
super().__init__()
|
|
3383
|
+
|
|
3384
|
+
self.config = config
|
|
3385
|
+
self.dim = config.hidden_size
|
|
3386
|
+
self.heads = config.num_attention_heads
|
|
3387
|
+
self.inner_dim = config.head_dim * config.num_attention_heads
|
|
3388
|
+
self.dropout = config.dropout
|
|
3389
|
+
self._attn_implementation = config._attn_implementation
|
|
3390
|
+
self.is_causal = False
|
|
3391
|
+
|
|
3392
|
+
self.to_q = nn.Linear(config.hidden_size, self.inner_dim)
|
|
3393
|
+
self.to_k = nn.Linear(config.hidden_size, self.inner_dim)
|
|
3394
|
+
self.to_v = nn.Linear(config.hidden_size, self.inner_dim)
|
|
3395
|
+
|
|
3396
|
+
self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)])
|
|
3397
|
+
|
|
3398
|
+
def forward(
|
|
3399
|
+
self,
|
|
3400
|
+
x, # noised input x
|
|
3401
|
+
rope=None, # rotary position embedding for x
|
|
3402
|
+
mask=None,
|
|
3403
|
+
) -> torch.Tensor:
|
|
3404
|
+
batch_size = x.shape[0]
|
|
3405
|
+
|
|
3406
|
+
# `sample` projections.
|
|
3407
|
+
query = self.to_q(x)
|
|
3408
|
+
key = self.to_k(x)
|
|
3409
|
+
value = self.to_v(x)
|
|
3410
|
+
|
|
3411
|
+
# attention
|
|
3412
|
+
inner_dim = key.shape[-1]
|
|
3413
|
+
head_dim = inner_dim // self.heads
|
|
3414
|
+
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
|
3415
|
+
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
|
3416
|
+
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
|
3417
|
+
|
|
3418
|
+
# apply rotary position embedding
|
|
3419
|
+
# Due to the training process, RoPE is only applied to the first head; will be fixed in a future release
|
|
3420
|
+
cos, sin = rope
|
|
3421
|
+
query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin)
|
|
3422
|
+
|
|
3423
|
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation]
|
|
3424
|
+
x, _ = attention_interface(
|
|
3425
|
+
self,
|
|
3426
|
+
query,
|
|
3427
|
+
key,
|
|
3428
|
+
value,
|
|
3429
|
+
attention_mask=mask,
|
|
3430
|
+
is_causal=False,
|
|
3431
|
+
)
|
|
3432
|
+
|
|
3433
|
+
# mask: e.g., when inference gets a batch with different target durations, mask out the padding
|
|
3434
|
+
# x = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
|
3435
|
+
x = x.reshape(batch_size, -1, self.heads * head_dim)
|
|
3436
|
+
x = x.to(query.dtype)
|
|
3437
|
+
|
|
3438
|
+
# linear proj
|
|
3439
|
+
x = self.to_out[0](x)
|
|
3440
|
+
# dropout
|
|
3441
|
+
x = self.to_out[1](x)
|
|
3442
|
+
|
|
3443
|
+
return x
|
|
3444
|
+
|
|
3445
|
+
|
|
3446
|
+
# time step conditioning embedding
|
|
3447
|
+
class SinusPositionEmbedding(nn.Module):
|
|
3448
|
+
def __init__(self, dim):
|
|
3449
|
+
super().__init__()
|
|
3450
|
+
self.dim = dim
|
|
3451
|
+
|
|
3452
|
+
def forward(self, x, scale=1000):
|
|
3453
|
+
device = x.device
|
|
3454
|
+
half_dim = self.dim // 2
|
|
3455
|
+
emb = math.log(10000) / (half_dim - 1)
|
|
3456
|
+
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
|
3457
|
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
|
3458
|
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
|
3459
|
+
return emb.type_as(x)
|
|
3460
|
+
|
|
3461
|
+
|
|
3462
|
+
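# Illustrative note (not part of the diff): SinusPositionEmbedding maps a 1-D
# batch of timesteps to [B, dim] sinusoidal features: with half_dim = dim // 2
# and frequencies f_i = exp(-i * ln(10000) / (half_dim - 1)), it concatenates
# sin(scale * t * f_i) and cos(scale * t * f_i), e.g.
#   SinusPositionEmbedding(256)(torch.tensor([0.1, 0.5])).shape  # -> [2, 256]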
class TimestepEmbedding(nn.Module):
|
|
3463
|
+
def __init__(self, dim, freq_embed_dim=256):
|
|
3464
|
+
super().__init__()
|
|
3465
|
+
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
|
3466
|
+
self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)])
|
|
3467
|
+
|
|
3468
|
+
def forward(self, timestep): # noqa: F821
|
|
3469
|
+
time_hidden = self.time_embed(timestep)
|
|
3470
|
+
time_hidden = time_hidden.to(timestep.dtype)
|
|
3471
|
+
for layer in self.time_mlp:
|
|
3472
|
+
time_hidden = layer(time_hidden) # b d
|
|
3473
|
+
return time_hidden
|
|
3474
|
+
|
|
3475
|
+
|
|
3476
|
+
class DiTBlock(nn.Module):
|
|
3477
|
+
def __init__(self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0):
|
|
3478
|
+
super().__init__()
|
|
3479
|
+
self.attn_norm = AdaLayerNormZero(config.hidden_size)
|
|
3480
|
+
|
|
3481
|
+
self.attn = DiTAttention(config)
|
|
3482
|
+
self.look_ahead_block = look_ahead_block
|
|
3483
|
+
self.look_backward_block = look_backward_block
|
|
3484
|
+
self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6)
|
|
3485
|
+
self.ff = FeedForward(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout)
|
|
3486
|
+
|
|
3487
|
+
def forward(self, x, t, rope=None, block_diff=None): # x: noised input, t: time embedding
|
|
3488
|
+
# pre-norm & modulation for attention input
|
|
3489
|
+
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
|
3490
|
+
|
|
3491
|
+
# attention
|
|
3492
|
+
attn_output = self.attn(
|
|
3493
|
+
x=norm,
|
|
3494
|
+
rope=rope,
|
|
3495
|
+
mask=(block_diff >= -float(self.look_backward_block)) & (block_diff <= float(self.look_ahead_block)),
|
|
3496
|
+
)
|
|
3497
|
+
|
|
3498
|
+
# process attention output for input x
|
|
3499
|
+
x = x + gate_msa.unsqueeze(1) * attn_output
|
|
3500
|
+
|
|
3501
|
+
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
|
3502
|
+
ff_output = self.ff(norm)
|
|
3503
|
+
x = x + gate_mlp.unsqueeze(1) * ff_output
|
|
3504
|
+
|
|
3505
|
+
return x
|
|
3506
|
+
|
|
3507
|
+
|
|
3508
|
+
class SnakeBeta(nn.Module):
|
|
3509
|
+
|
|
3510
|
+
def __init__(self, in_features, alpha=1.0):
|
|
3511
|
+
super().__init__()
|
|
3512
|
+
self.in_features = in_features
|
|
3513
|
+
|
|
3514
|
+
# initialize alpha and beta to zero (exponentiated in forward, so they start at exp(0) = 1)
|
|
3515
|
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
|
3516
|
+
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
|
3517
|
+
|
|
3518
|
+
self.no_div_by_zero = 0.000000001
|
|
3519
|
+
|
|
3520
|
+
def forward(self, x):
|
|
3521
|
+
"""
|
|
3522
|
+
Forward pass of the function.
|
|
3523
|
+
Applies the function to the input elementwise.
|
|
3524
|
+
SnakeBeta(x) := x + (1 / beta) * sin^2(alpha * x), where alpha = exp(self.alpha) and beta = exp(self.beta)
|
|
3525
|
+
"""
|
|
3526
|
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
|
3527
|
+
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
|
3528
|
+
alpha = torch.exp(alpha)
|
|
3529
|
+
beta = torch.exp(beta)
|
|
3530
|
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
|
|
3531
|
+
|
|
3532
|
+
return x
|
|
3533
|
+
|
|
3534
|
+
|
|
3535
|
+
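# Illustrative note (not part of the diff): with the zero initialization above,
# exp(alpha) = exp(beta) = 1, so SnakeBeta initially reduces to the plain Snake
# activation x + sin(x) ** 2; a quick check, assuming the class is in scope:
#   act = SnakeBeta(in_features=3)
#   x = torch.randn(2, 3, 5)  # [B, C, T]
#   torch.allclose(act(x), x + torch.sin(x) ** 2)  # True (up to the eps term)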
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
|
3536
|
+
even = kernel_size % 2 == 0
|
|
3537
|
+
half_size = kernel_size // 2
|
|
3538
|
+
|
|
3539
|
+
# For kaiser window
|
|
3540
|
+
delta_f = 4 * half_width
|
|
3541
|
+
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
|
3542
|
+
if A > 50.0:
|
|
3543
|
+
beta = 0.1102 * (A - 8.7)
|
|
3544
|
+
elif A >= 21.0:
|
|
3545
|
+
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
|
3546
|
+
else:
|
|
3547
|
+
beta = 0.0
|
|
3548
|
+
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32)
|
|
3549
|
+
|
|
3550
|
+
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
|
3551
|
+
if even:
|
|
3552
|
+
time = torch.arange(-half_size, half_size) + 0.5
|
|
3553
|
+
else:
|
|
3554
|
+
time = torch.arange(kernel_size) - half_size
|
|
3555
|
+
if cutoff == 0:
|
|
3556
|
+
filter_ = torch.zeros_like(time)
|
|
3557
|
+
else:
|
|
3558
|
+
filter_ = 2 * cutoff * window * torch.sinc(2 * cutoff * time)
|
|
3559
|
+
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
|
3560
|
+
# of the constant component in the input signal.
|
|
3561
|
+
filter_ /= filter_.sum()
|
|
3562
|
+
filter = filter_.view(1, 1, kernel_size)
|
|
3563
|
+
|
|
3564
|
+
return filter
|
|
3565
|
+
|
|
3566
|
+
|
|
3567
|
+
class UpSample1d(nn.Module):
|
|
3568
|
+
def __init__(self, ratio=2, kernel_size=None):
|
|
3569
|
+
super().__init__()
|
|
3570
|
+
self.ratio = ratio
|
|
3571
|
+
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
|
3572
|
+
self.stride = ratio
|
|
3573
|
+
self.pad = self.kernel_size // ratio - 1
|
|
3574
|
+
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
|
3575
|
+
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
|
3576
|
+
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
|
|
3577
|
+
self.register_buffer("filter", filter, persistent=False)
|
|
3578
|
+
|
|
3579
|
+
# x: [B, C, T]
|
|
3580
|
+
def forward(self, x):
|
|
3581
|
+
_, C, _ = x.shape
|
|
3582
|
+
|
|
3583
|
+
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
|
3584
|
+
x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
|
3585
|
+
x = x[..., self.pad_left : -self.pad_right]
|
|
3586
|
+
|
|
3587
|
+
return x
|
|
3588
|
+
|
|
3589
|
+
|
|
3590
|
+
class DownSample1d(nn.Module):
|
|
3591
|
+
def __init__(self, ratio=2, kernel_size=None):
|
|
3592
|
+
super().__init__()
|
|
3593
|
+
cutoff = 0.5 / ratio
|
|
3594
|
+
half_width = 0.6 / ratio
|
|
3595
|
+
if cutoff < -0.0:
|
|
3596
|
+
raise ValueError("Minimum cutoff must be larger than zero.")
|
|
3597
|
+
if cutoff > 0.5:
|
|
3598
|
+
raise ValueError("A cutoff above 0.5 does not make sense.")
|
|
3599
|
+
self.kernel_size = kernel_size
|
|
3600
|
+
self.even = kernel_size % 2 == 0
|
|
3601
|
+
self.pad_left = kernel_size // 2 - int(self.even)
|
|
3602
|
+
self.pad_right = kernel_size // 2
|
|
3603
|
+
self.stride = ratio
|
|
3604
|
+
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
|
3605
|
+
self.register_buffer("filter", filter, persistent=False)
|
|
3606
|
+
|
|
3607
|
+
def forward(self, x):
|
|
3608
|
+
_, C, _ = x.shape
|
|
3609
|
+
|
|
3610
|
+
x = F.pad(x, (self.pad_left, self.pad_right), mode="replicate")
|
|
3611
|
+
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
|
3612
|
+
|
|
3613
|
+
return out
|
|
3614
|
+
|
|
3615
|
+
|
|
3616
|
+
class TorchActivation1d(nn.Module):
|
|
3617
|
+
def __init__(
|
|
3618
|
+
self,
|
|
3619
|
+
activation,
|
|
3620
|
+
up_ratio: int = 2,
|
|
3621
|
+
down_ratio: int = 2,
|
|
3622
|
+
up_kernel_size: int = 12,
|
|
3623
|
+
down_kernel_size: int = 12,
|
|
3624
|
+
):
|
|
3625
|
+
super().__init__()
|
|
3626
|
+
self.up_ratio = up_ratio
|
|
3627
|
+
self.down_ratio = down_ratio
|
|
3628
|
+
self.act = activation
|
|
3629
|
+
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
|
3630
|
+
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
|
3631
|
+
|
|
3632
|
+
# x: [B,C,T]
|
|
3633
|
+
def forward(self, x):
|
|
3634
|
+
x = self.upsample(x)
|
|
3635
|
+
x = self.act(x)
|
|
3636
|
+
x = self.downsample(x)
|
|
3637
|
+
|
|
3638
|
+
return x
|
|
3639
|
+
|
|
3640
|
+
|
|
3641
|
+
class AMPBlock(torch.nn.Module):
|
|
3642
|
+
def __init__(
|
|
3643
|
+
self,
|
|
3644
|
+
channels,
|
|
3645
|
+
kernel_size=3,
|
|
3646
|
+
dilation=(1, 3, 5),
|
|
3647
|
+
):
|
|
3648
|
+
super().__init__()
|
|
3649
|
+
|
|
3650
|
+
self.convs1 = nn.ModuleList(
|
|
3651
|
+
[
|
|
3652
|
+
nn.Conv1d(
|
|
3653
|
+
channels,
|
|
3654
|
+
channels,
|
|
3655
|
+
kernel_size,
|
|
3656
|
+
1,
|
|
3657
|
+
dilation=dilation[0],
|
|
3658
|
+
padding=self._get_padding(kernel_size, dilation[0]),
|
|
3659
|
+
),
|
|
3660
|
+
nn.Conv1d(
|
|
3661
|
+
channels,
|
|
3662
|
+
channels,
|
|
3663
|
+
kernel_size,
|
|
3664
|
+
1,
|
|
3665
|
+
dilation=dilation[1],
|
|
3666
|
+
padding=self._get_padding(kernel_size, dilation[1]),
|
|
3667
|
+
),
|
|
3668
|
+
nn.Conv1d(
|
|
3669
|
+
channels,
|
|
3670
|
+
channels,
|
|
3671
|
+
kernel_size,
|
|
3672
|
+
1,
|
|
3673
|
+
dilation=dilation[2],
|
|
3674
|
+
padding=self._get_padding(kernel_size, dilation[2]),
|
|
3675
|
+
),
|
|
3676
|
+
]
|
|
3677
|
+
)
|
|
3678
|
+
|
|
3679
|
+
self.convs2 = nn.ModuleList(
|
|
3680
|
+
[
|
|
3681
|
+
nn.Conv1d(
|
|
3682
|
+
channels,
|
|
3683
|
+
channels,
|
|
3684
|
+
kernel_size,
|
|
3685
|
+
1,
|
|
3686
|
+
dilation=1,
|
|
3687
|
+
padding=self._get_padding(kernel_size, 1),
|
|
3688
|
+
),
|
|
3689
|
+
nn.Conv1d(
|
|
3690
|
+
channels,
|
|
3691
|
+
channels,
|
|
3692
|
+
kernel_size,
|
|
3693
|
+
1,
|
|
3694
|
+
dilation=1,
|
|
3695
|
+
padding=self._get_padding(kernel_size, 1),
|
|
3696
|
+
),
|
|
3697
|
+
nn.Conv1d(
|
|
3698
|
+
channels,
|
|
3699
|
+
channels,
|
|
3700
|
+
kernel_size,
|
|
3701
|
+
1,
|
|
3702
|
+
dilation=1,
|
|
3703
|
+
padding=self._get_padding(kernel_size, 1),
|
|
3704
|
+
),
|
|
3705
|
+
]
|
|
3706
|
+
)
|
|
3707
|
+
|
|
3708
|
+
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
|
3709
|
+
|
|
3710
|
+
self.activations = nn.ModuleList(
|
|
3711
|
+
[TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)]
|
|
3712
|
+
)
|
|
3713
|
+
|
|
3714
|
+
def _get_padding(self, kernel_size, dilation=1):
|
|
3715
|
+
return int((kernel_size * dilation - dilation) / 2)
|
|
3716
|
+
|
|
3717
|
+
def forward(self, x):
|
|
3718
|
+
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
|
3719
|
+
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
|
3720
|
+
xt = a1(x)
|
|
3721
|
+
xt = c1(xt)
|
|
3722
|
+
xt = a2(xt)
|
|
3723
|
+
xt = c2(xt)
|
|
3724
|
+
x = xt + x
|
|
3725
|
+
|
|
3726
|
+
return x
|
|
3727
|
+
|
|
3728
|
+
|
|
3729
|
+
class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel):
|
|
3730
|
+
config_class = Qwen2_5OmniBigVGANConfig
|
|
3731
|
+
|
|
3732
|
+
def __init__(self, config: Qwen2_5OmniBigVGANConfig):
|
|
3733
|
+
super().__init__(config)
|
|
3734
|
+
|
|
3735
|
+
self.num_kernels = len(config.resblock_kernel_sizes)
|
|
3736
|
+
self.num_upsamples = len(config.upsample_rates)
|
|
3737
|
+
|
|
3738
|
+
# pre conv
|
|
3739
|
+
self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3)
|
|
3740
|
+
|
|
3741
|
+
# transposed conv-based upsamplers; anti-aliasing is not applied here
|
|
3742
|
+
self.ups = nn.ModuleList()
|
|
3743
|
+
for i, (u, k) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
|
|
3744
|
+
self.ups.append(
|
|
3745
|
+
nn.ModuleList(
|
|
3746
|
+
[
|
|
3747
|
+
ConvTranspose1d(
|
|
3748
|
+
config.upsample_initial_channel // (2**i),
|
|
3749
|
+
config.upsample_initial_channel // (2 ** (i + 1)),
|
|
3750
|
+
k,
|
|
3751
|
+
u,
|
|
3752
|
+
padding=(k - u) // 2,
|
|
3753
|
+
)
|
|
3754
|
+
]
|
|
3755
|
+
)
|
|
3756
|
+
)
|
|
3757
|
+
|
|
3758
|
+
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
|
3759
|
+
self.resblocks = nn.ModuleList()
|
|
3760
|
+
for i in range(len(self.ups)):
|
|
3761
|
+
ch = config.upsample_initial_channel // (2 ** (i + 1))
|
|
3762
|
+
for j, (k, d) in enumerate(zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes)):
|
|
3763
|
+
self.resblocks.append(AMPBlock(ch, k, d))
|
|
3764
|
+
|
|
3765
|
+
# post conv
|
|
3766
|
+
self.activation_post = TorchActivation1d(activation=SnakeBeta(ch))
|
|
3767
|
+
|
|
3768
|
+
self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
|
3769
|
+
|
|
3770
|
+
def _normalize(self, S, max_abs_value, min_db):
|
|
3771
|
+
return torch.clamp(
|
|
3772
|
+
(2 * max_abs_value) * ((S - min_db) / (-min_db)) - max_abs_value, -max_abs_value, max_abs_value
|
|
3773
|
+
)
|
|
3774
|
+
|
|
3775
|
+
def _amp_to_db(self, x, min_level_db):
|
|
3776
|
+
min_level = np.exp(min_level_db / 20 * np.log(10))
|
|
3777
|
+
min_level = torch.ones_like(x) * min_level
|
|
3778
|
+
return 20 * torch.log10(torch.maximum(min_level, x))
|
|
3779
|
+
|
|
3780
|
+
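# Descriptive note (not part of the diff): apm_to_db converts the DiT output,
# a log-amplitude mel spectrogram, back to amplitude via exp, to decibels via
# _amp_to_db with a -115 dB floor, subtracts a 20 dB reference, and finally
# rescales to [-1, 1] with _normalize before the vocoder consumes it.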
def apm_to_db(self, apm_mel):
|
|
3781
|
+
mel_spec = torch.exp(apm_mel)
|
|
3782
|
+
|
|
3783
|
+
mel_spec = self._amp_to_db(mel_spec, -115) - 20
|
|
3784
|
+
mel_spec = self._normalize(mel_spec, 1, -115)
|
|
3785
|
+
|
|
3786
|
+
return mel_spec
|
|
3787
|
+
|
|
3788
|
+
def forward(self, apm_mel):
|
|
3789
|
+
mel_spec = self.apm_to_db(apm_mel)
|
|
3790
|
+
# pre conv
|
|
3791
|
+
hidden = self.conv_pre(mel_spec)
|
|
3792
|
+
|
|
3793
|
+
for i in range(self.num_upsamples):
|
|
3794
|
+
# upsampling
|
|
3795
|
+
for i_up in range(len(self.ups[i])):
|
|
3796
|
+
ups_i = cast(nn.ModuleList, self.ups[i])
|
|
3797
|
+
hidden = ups_i[i_up](hidden)
|
|
3798
|
+
# AMP blocks
|
|
3799
|
+
xs = None
|
|
3800
|
+
for j in range(self.num_kernels):
|
|
3801
|
+
if xs is None:
|
|
3802
|
+
xs = self.resblocks[i * self.num_kernels + j](hidden)
|
|
3803
|
+
else:
|
|
3804
|
+
xs += self.resblocks[i * self.num_kernels + j](hidden)
|
|
3805
|
+
assert xs is not None
|
|
3806
|
+
hidden = xs / self.num_kernels
|
|
3807
|
+
|
|
3808
|
+
# post conv
|
|
3809
|
+
hidden = self.activation_post(hidden)
|
|
3810
|
+
hidden = self.conv_post(hidden)
|
|
3811
|
+
audio = torch.clamp(hidden, min=-1.0, max=1.0) # bound the output to [-1, 1]
|
|
3812
|
+
|
|
3813
|
+
return audio.squeeze().cpu()
|
|
3814
|
+
|
|
3815
|
+
|
|
3816
|
+
class ODESolverRK4:
|
|
3817
|
+
def __init__(self, func, y0):
|
|
3818
|
+
self.func = func
|
|
3819
|
+
self.y0 = y0
|
|
3820
|
+
|
|
3821
|
+
self._one_third = 1 / 3
|
|
3822
|
+
self._two_thirds = 2 / 3
|
|
3823
|
+
|
|
3824
|
+
def _rk4_alt_step_func(self, func, t0, dt, t1, y0, f0=None):
|
|
3825
|
+
k1 = f0
|
|
3826
|
+
if k1 is None:
|
|
3827
|
+
k1 = func(t0, y0)
|
|
3828
|
+
k2 = func(t0 + dt * self._one_third, y0 + dt * k1 * self._one_third)
|
|
3829
|
+
k3 = func(t0 + dt * self._two_thirds, y0 + dt * (k2 - k1 * self._one_third))
|
|
3830
|
+
k4 = func(t1, y0 + dt * (k1 - k2 + k3))
|
|
3831
|
+
return (k1 + 3 * (k2 + k3) + k4) * dt * 0.125
|
|
3832
|
+
|
|
3833
|
+
def _step_func(self, func, t0, dt, t1, y0):
|
|
3834
|
+
f0 = func(t0, y0)
|
|
3835
|
+
return self._rk4_alt_step_func(func, t0, dt, t1, y0, f0=f0), f0
|
|
3836
|
+
|
|
3837
|
+
def _linear_interp(self, t0, t1, y0, y1, t):
|
|
3838
|
+
if t == t0:
|
|
3839
|
+
return y0
|
|
3840
|
+
if t == t1:
|
|
3841
|
+
return y1
|
|
3842
|
+
slope = (t - t0) / (t1 - t0)
|
|
3843
|
+
return y0 + slope * (y1 - y0)
|
|
3844
|
+
|
|
3845
|
+
def integrate(self, t):
|
|
3846
|
+
solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
|
|
3847
|
+
solution[0] = self.y0
|
|
3848
|
+
|
|
3849
|
+
j = 1
|
|
3850
|
+
y0 = self.y0
|
|
3851
|
+
for t0, t1 in zip(t[:-1], t[1:]):
|
|
3852
|
+
dt = t1 - t0
|
|
3853
|
+
dy, f0 = self._step_func(self.func, t0, dt, t1, y0)
|
|
3854
|
+
y1 = y0 + dy
|
|
3855
|
+
|
|
3856
|
+
while j < len(t) and t1 >= t[j]:
|
|
3857
|
+
solution[j] = self._linear_interp(t0, t1, y0, y1, t[j])
|
|
3858
|
+
j += 1
|
|
3859
|
+
y0 = y1
|
|
3860
|
+
|
|
3861
|
+
return solution
|
|
3862
|
+
|
|
3863
|
+
|
|
3864
|
+
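# Illustrative note (not part of the diff): ODESolverRK4 integrates dy/dt = func(t, y)
# with a fixed-step fourth-order Runge-Kutta scheme (the 3/8 variant) over the
# given time grid; e.g., for dy/dt = y starting from y0 = 1 the terminal value
# approaches e, assuming torch is imported as in this module:
#   solver = ODESolverRK4(func=lambda t, y: y, y0=torch.ones(1))
#   solver.integrate(torch.linspace(0.0, 1.0, 11))[-1]  # ~= tensor([2.7183])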
class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel):
|
|
3865
|
+
config_class = Qwen2_5OmniDiTConfig
|
|
3866
|
+
_no_split_modules = ["DiTBlock"]
|
|
3867
|
+
|
|
3868
|
+
def __init__(self, config: Qwen2_5OmniDiTConfig):
|
|
3869
|
+
super().__init__(config)
|
|
3870
|
+
self.mel_dim = config.mel_dim
|
|
3871
|
+
self.repeats = config.repeats
|
|
3872
|
+
self.time_embed = TimestepEmbedding(config.hidden_size)
|
|
3873
|
+
|
|
3874
|
+
self.text_embed = CodecEmbedding(config.num_embeds, config.emb_dim, config.repeats)
|
|
3875
|
+
self.input_embed = InputEmbedding(config)
|
|
3876
|
+
|
|
3877
|
+
self.rotary_embed = RotaryEmbedding(config.head_dim)
|
|
3878
|
+
# self.rotary_embed = Qwen2_5OmniDiTRotaryEmbedding(config)
|
|
3879
|
+
|
|
3880
|
+
self.hidden_size = config.hidden_size
|
|
3881
|
+
self.layers = config.num_hidden_layers
|
|
3882
|
+
self.block_size = config.block_size
|
|
3883
|
+
self.num_attention_heads = config.num_attention_heads
|
|
3884
|
+
|
|
3885
|
+
self.transformer_blocks = nn.ModuleList()
|
|
3886
|
+
for i in range(config.num_hidden_layers):
|
|
3887
|
+
self.transformer_blocks.append(
|
|
3888
|
+
DiTBlock(
|
|
3889
|
+
config,
|
|
3890
|
+
look_ahead_block=1 if i in config.look_ahead_layers else 0,
|
|
3891
|
+
look_backward_block=1 if i in config.look_backward_layers else 0,
|
|
3892
|
+
)
|
|
3893
|
+
)
|
|
3894
|
+
|
|
3895
|
+
self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation
|
|
3896
|
+
self.proj_out = nn.Linear(config.hidden_size, config.mel_dim)
|
|
3897
|
+
|
|
3898
|
+
def _create_block_diff(self, x):
|
|
3899
|
+
batch, seq_len = x.shape[0], x.shape[1]
|
|
3900
|
+
block_indices = torch.arange(seq_len, device=x.device) // self.block_size # [seq_length]
|
|
3901
|
+
|
|
3902
|
+
block_i = block_indices.unsqueeze(1) # [seq_length, 1]
|
|
3903
|
+
block_j = block_indices.unsqueeze(0) # [1, seq_length]
|
|
3904
|
+
|
|
3905
|
+
block_diff = block_j - block_i # (n, n)
|
|
3906
|
+
|
|
3907
|
+
return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len)
|
|
3908
|
+
|
|
3909
|
+
def forward(
|
|
3910
|
+
self,
|
|
3911
|
+
x, # noised input audio
|
|
3912
|
+
cond, # masked cond audio
|
|
3913
|
+
spk, # spk embedding
|
|
3914
|
+
code, # code
|
|
3915
|
+
time, # time step # noqa: F821 F722
|
|
3916
|
+
drop_audio_cond=False, # cfg for cond audio
|
|
3917
|
+
drop_code=False, # cfg for code
|
|
3918
|
+
cfg=True,
|
|
3919
|
+
):
|
|
3920
|
+
batch = x.shape[0]
|
|
3921
|
+
if time.ndim == 0:
|
|
3922
|
+
time = time.repeat(batch)
|
|
3923
|
+
|
|
3924
|
+
# t: conditioning time, c: context (code + masked cond audio), x: noised input audio
|
|
3925
|
+
t = self.time_embed(time)
|
|
3926
|
+
code_embed = self.text_embed(code, drop_code=False if cfg else drop_code)
|
|
3927
|
+
code_embed_uncond = self.text_embed(code, drop_code=True) if cfg else None
|
|
3928
|
+
hidden = self.input_embed(
|
|
3929
|
+
x,
|
|
3930
|
+
spk,
|
|
3931
|
+
cond,
|
|
3932
|
+
code_embed,
|
|
3933
|
+
drop_audio_cond=drop_audio_cond,
|
|
3934
|
+
code_embed_uncond=code_embed_uncond,
|
|
3935
|
+
cfg=cfg,
|
|
3936
|
+
)
|
|
3937
|
+
|
|
3938
|
+
# rope = self.rotary_embed(x, torch.arange(seq_len, device=x.device).repeat(batch, 1))
|
|
3939
|
+
rope = self.rotary_embed(hidden)
|
|
3940
|
+
|
|
3941
|
+
block_diff = self._create_block_diff(hidden)
|
|
3942
|
+
|
|
3943
|
+
for block in self.transformer_blocks:
|
|
3944
|
+
hidden = block(hidden, t, rope=rope, block_diff=block_diff)
|
|
3945
|
+
|
|
3946
|
+
hidden = self.norm_out(hidden, t)
|
|
3947
|
+
output = self.proj_out(hidden)
|
|
3948
|
+
|
|
3949
|
+
return output
|
|
3950
|
+
|
|
3951
|
+
@torch.no_grad()
|
|
3952
|
+
def sample(
|
|
3953
|
+
self,
|
|
3954
|
+
cond,
|
|
3955
|
+
ref_mel,
|
|
3956
|
+
code,
|
|
3957
|
+
steps=10,
|
|
3958
|
+
cfg_strength=0.5,
|
|
3959
|
+
sway_sampling_coef=-1.0,
|
|
3960
|
+
):
|
|
3961
|
+
y_all = torch.randn([1, 30000, self.mel_dim], dtype=ref_mel.dtype)
|
|
3962
|
+
max_duration = code.shape[1] * self.repeats
|
|
3963
|
+
y0 = y_all[:, :max_duration].to(code.device)
|
|
3964
|
+
batch = ref_mel.shape[0]
|
|
3965
|
+
cond = cond.unsqueeze(1).repeat(1, max_duration, 1)
|
|
3966
|
+
assert batch == 1, "only batch size = 1 is supported currently"
|
|
3967
|
+
|
|
3968
|
+
def fn(t, x):
|
|
3969
|
+
if cfg_strength < 1e-5:
|
|
3970
|
+
pred = self(x=x, spk=cond, cond=ref_mel, code=code, time=t, drop_audio_cond=False, drop_code=False)
|
|
3971
|
+
return pred
|
|
3972
|
+
|
|
3973
|
+
output = self(x=x, code=code, spk=cond, cond=ref_mel, time=t, cfg=True)
|
|
3974
|
+
pred, null_pred = torch.chunk(output, 2, dim=0)
|
|
3975
|
+
|
|
3976
|
+
return pred + (pred - null_pred) * cfg_strength
|
|
3977
|
+
|
|
3978
|
+
t_start = 0
|
|
3979
|
+
t = torch.linspace(t_start, 1, steps, device=code.device, dtype=cond.dtype)
|
|
3980
|
+
if sway_sampling_coef is not None:
|
|
3981
|
+
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
|
3982
|
+
|
|
3983
|
+
solver = ODESolverRK4(func=fn, y0=y0)
|
|
3984
|
+
trajectory = solver.integrate(t)
|
|
3985
|
+
|
|
3986
|
+
generated = trajectory[-1]
|
|
3987
|
+
generated_mel_spec = generated.permute(0, 2, 1)
|
|
3988
|
+
return generated_mel_spec
|
|
3989
|
+
|
|
3990
|
+
|
|
3991
|
+
@add_start_docstrings(
|
|
3992
|
+
(
|
|
3993
|
+
"The full Qwen2.5Omni Token2Wav model. Consists a DiT model take speech"
|
|
3994
|
+
" tokens as input and predict mel spectrogram and a BigVGAN vocoder take"
|
|
3995
|
+
" mel spectrogram as input and predict waveform."
|
|
3996
|
+
),
|
|
3997
|
+
QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniToken2WavConfig"),
|
|
3998
|
+
)
|
|
3999
|
+
class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel):
|
|
4000
|
+
config_class = Qwen2_5OmniToken2WavConfig
|
|
4001
|
+
base_model_prefix = "model"
|
|
4002
|
+
_no_split_modules = ["Qwen2_5OmniToken2WavDiTModel", "Qwen2_5OmniToken2WavBigVGANModel"]
|
|
4003
|
+
|
|
4004
|
+
def __init__(self, config: Qwen2_5OmniToken2WavConfig):
|
|
4005
|
+
super().__init__(config)
|
|
4006
|
+
attn_impl = config._attn_implementation
|
|
4007
|
+
if config._attn_implementation == "flash_attention_2":
|
|
4008
|
+
logger.warning_once(
|
|
4009
|
+
"Qwen2_5OmniToken2WavModel must inference with fp32, but "
|
|
4010
|
+
"flash_attention_2 only supports fp16 and bf16, "
|
|
4011
|
+
"attention implementation of Qwen2_5OmniToken2WavModel will fallback to sdpa."
|
|
4012
|
+
)
|
|
4013
|
+
attn_impl = "sdpa"
|
|
4014
|
+
elif config._attn_implementation == "eager":
|
|
4015
|
+
logger.warning_once(
|
|
4016
|
+
"Qwen2_5OmniToken2WavModel does not support eager attention implementation, " "fall back to sdpa"
|
|
4017
|
+
)
|
|
4018
|
+
attn_impl = "sdpa"
|
|
4019
|
+
self.code2wav_dit_model = Qwen2_5OmniToken2WavDiTModel._from_config(
|
|
4020
|
+
config.dit_config, attn_implementation=attn_impl
|
|
4021
|
+
)
|
|
4022
|
+
self.code2wav_bigvgan_model = Qwen2_5OmniToken2WavBigVGANModel._from_config(
|
|
4023
|
+
config.bigvgan_config, attn_implementation=attn_impl
|
|
4024
|
+
)
|
|
4025
|
+
|
|
4026
|
+
def forward(
|
|
4027
|
+
self,
|
|
4028
|
+
code,
|
|
4029
|
+
cond,
|
|
4030
|
+
ref_mel,
|
|
4031
|
+
steps=10,
|
|
4032
|
+
cfg_strength=0.5,
|
|
4033
|
+
sway_sampling_coef=-1.0,
|
|
4034
|
+
**kwargs,
|
|
4035
|
+
):
|
|
4036
|
+
generated_mel = self.code2wav_dit_model.sample(
|
|
4037
|
+
cond,
|
|
4038
|
+
ref_mel,
|
|
4039
|
+
code,
|
|
4040
|
+
steps=steps,
|
|
4041
|
+
cfg_strength=cfg_strength,
|
|
4042
|
+
sway_sampling_coef=sway_sampling_coef,
|
|
4043
|
+
)
|
|
4044
|
+
waveform = self.code2wav_bigvgan_model(generated_mel)
|
|
4045
|
+
return waveform
|
|
4046
|
+
|
|
4047
|
+
|
|
4048
|
+
@add_start_docstrings(
|
|
4049
|
+
"""""",
|
|
4050
|
+
QWEN2_5OMNI_START_DOCSTRING.format(config_class=Qwen2_5OmniConfig),
|
|
4051
|
+
)
|
|
4052
|
+
class Qwen2_5OmniModel(Qwen2_5OmniPreTrainedModel):
|
|
4053
|
+
config_class = Qwen2_5OmniConfig
|
|
4054
|
+
_no_split_modules = [
|
|
4055
|
+
"Qwen2_5OmniTalkerForConditionalGeneration",
|
|
4056
|
+
"Qwen2_5OmniToken2WavModel",
|
|
4057
|
+
]
|
|
4058
|
+
|
|
4059
|
+
def __init__(self, config):
|
|
4060
|
+
super().__init__(config)
|
|
4061
|
+
|
|
4062
|
+
self.thinker = Qwen2_5OmniThinkerForConditionalGeneration(config.thinker_config)
|
|
4063
|
+
|
|
4064
|
+
self.has_talker = config.enable_audio_output
|
|
4065
|
+
self.speaker_map = {}
|
|
4066
|
+
if config.enable_audio_output:
|
|
4067
|
+
self.enable_talker()
|
|
4068
|
+
|
|
4069
|
+
def enable_talker(self):
|
|
4070
|
+
self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
|
|
4071
|
+
self.token2wav = Qwen2_5OmniToken2WavModel(self.config.token2wav_config)
|
|
4072
|
+
self.token2wav.float()
|
|
4073
|
+
self.has_talker = True
|
|
4074
|
+
|
|
4075
|
+
def load_speakers(self, path):
|
|
4076
|
+
for key, value in torch.load(path).items():
|
|
4077
|
+
self.speaker_map[key] = value
|
|
4078
|
+
logger.info("Speaker {} loaded".format(list(self.speaker_map.keys())))
|
|
4079
|
+
|
|
4080
|
+
def disable_talker(self):
|
|
4081
|
+
if hasattr(self, "talker"):
|
|
4082
|
+
del self.talker
|
|
4083
|
+
if hasattr(self, "token2wav"):
|
|
4084
|
+
del self.token2wav
|
|
4085
|
+
self.has_talker = False
|
|
4086
|
+
|
|
4087
|
+
@classmethod
|
|
4088
|
+
def can_generate(cls) -> bool:
|
|
4089
|
+
return True
|
|
4090
|
+
|
|
4091
|
+
@classmethod
|
|
4092
|
+
def from_pretrained(
|
|
4093
|
+
cls,
|
|
4094
|
+
pretrained_model_name_or_path,
|
|
4095
|
+
*model_args,
|
|
4096
|
+
config=None,
|
|
4097
|
+
cache_dir=None,
|
|
4098
|
+
ignore_mismatched_sizes=False,
|
|
4099
|
+
force_download=False,
|
|
4100
|
+
local_files_only=False,
|
|
4101
|
+
token=None,
|
|
4102
|
+
revision="main",
|
|
4103
|
+
use_safetensors=None,
|
|
4104
|
+
weights_only=True,
|
|
4105
|
+
**kwargs,
|
|
4106
|
+
):
|
|
4107
|
+
model = super().from_pretrained(
|
|
4108
|
+
pretrained_model_name_or_path,
|
|
4109
|
+
*model_args,
|
|
4110
|
+
config=config,
|
|
4111
|
+
cache_dir=cache_dir,
|
|
4112
|
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
|
4113
|
+
force_download=force_download,
|
|
4114
|
+
local_files_only=local_files_only,
|
|
4115
|
+
token=token,
|
|
4116
|
+
revision=revision,
|
|
4117
|
+
use_safetensors=use_safetensors,
|
|
4118
|
+
weights_only=weights_only,
|
|
4119
|
+
**kwargs,
|
|
4120
|
+
)
|
|
4121
|
+
spk_path = cached_file(
|
|
4122
|
+
pretrained_model_name_or_path,
|
|
4123
|
+
"spk_dict.pt",
|
|
4124
|
+
subfolder=kwargs.pop("subfolder", None),
|
|
4125
|
+
cache_dir=kwargs.pop("cache_dir", None),
|
|
4126
|
+
force_download=kwargs.pop("force_download", False),
|
|
4127
|
+
proxies=kwargs.pop("proxies", None),
|
|
4128
|
+
resume_download=kwargs.pop("resume_download", None),
|
|
4129
|
+
local_files_only=kwargs.pop("local_files_only", False),
|
|
4130
|
+
token=kwargs.pop("use_auth_token", None),
|
|
4131
|
+
revision=kwargs.pop("revision", None),
|
|
4132
|
+
)
|
|
4133
|
+
if spk_path is None:
|
|
4134
|
+
raise ValueError(f"""{pretrained_model_name_or_path}/{spk_path} not exists""")
|
|
4135
|
+
model.load_speakers(spk_path)
|
|
4136
|
+
|
|
4137
|
+
return model
|
|
4138
|
+
|
|
4139
|
+
@torch.no_grad()
|
|
4140
|
+
def generate(
|
|
4141
|
+
self,
|
|
4142
|
+
input_ids: Optional[torch.Tensor] = None,
|
|
4143
|
+
spk: str = "Chelsie",
|
|
4144
|
+
use_audio_in_video: bool = False,
|
|
4145
|
+
return_audio: Optional[bool] = None,
|
|
4146
|
+
thinker_max_new_tokens: int = 1024,
|
|
4147
|
+
talker_max_new_tokens: int = 4096,
|
|
4148
|
+
talker_do_sample: bool = True,
|
|
4149
|
+
talker_top_k: int = 40,
|
|
4150
|
+
talker_top_p: float = 0.8,
|
|
4151
|
+
talker_temperature: float = 0.9,
|
|
4152
|
+
talker_eos_token_id: list[int] = [8292, 8294],
|
|
4153
|
+
talker_repetition_penalty: float = 1.05,
|
|
4154
|
+
**kwargs,
|
|
4155
|
+
):
|
|
4156
|
+
if spk not in self.speaker_map:
|
|
4157
|
+
raise ValueError(f"{spk} is not availible, availible speakers: {self.speaker_map.keys()}")
|
|
4158
|
+
if return_audio and not self.has_talker:
|
|
4159
|
+
raise ValueError(
|
|
4160
|
+
"Cannot use talker when talker module not initalized. Use `enable_talker` "
|
|
4161
|
+
"method or set enable_talker in config to enable talker."
|
|
4162
|
+
)
|
|
4163
|
+
if return_audio is None:
|
|
4164
|
+
return_audio = self.has_talker
|
|
4165
|
+
assert input_ids is not None
|
|
4166
|
+
if input_ids.shape[0] != 1 and return_audio:
|
|
4167
|
+
raise NotImplementedError("Qwen2.5-Omni currently does not support batched inference with audio output")
|
|
4168
|
+
shared_kwargs = {"use_audio_in_video": use_audio_in_video}
|
|
4169
|
+
thinker_kwargs = {
|
|
4170
|
+
"max_new_tokens": thinker_max_new_tokens,
|
|
4171
|
+
}
|
|
4172
|
+
talker_kwargs: dict[str, Union[torch.Tensor, Any]] = {
|
|
4173
|
+
"max_new_tokens": talker_max_new_tokens,
|
|
4174
|
+
"do_sample": talker_do_sample,
|
|
4175
|
+
"top_k": talker_top_k,
|
|
4176
|
+
"top_p": talker_top_p,
|
|
4177
|
+
"temperature": talker_temperature,
|
|
4178
|
+
"eos_token_id": talker_eos_token_id,
|
|
4179
|
+
"repetition_penalty": talker_repetition_penalty,
|
|
4180
|
+
}
|
|
4181
|
+
token2wav_kwargs = {}
|
|
4182
|
+
|
|
4183
|
+
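# Descriptive note (not part of the diff): extra generate kwargs are routed by
# prefix: "thinker_*" goes to the thinker, "talker_*" to the talker, and
# "token2wav_*" to the vocoder. For example, a hypothetical call
# model.generate(**inputs, thinker_temperature=0.7, thinker_top_p=0.9) forwards
# temperature=0.7 and top_p=0.9 to thinker.generate only.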
for key, value in kwargs.items():
|
|
4184
|
+
if key.startswith("thinker_"):
|
|
4185
|
+
thinker_kwargs[key[len("thinker_") :]] = value
|
|
4186
|
+
elif key.startswith("talker_"):
|
|
4187
|
+
talker_kwargs[key[len("talker_") :]] = value
|
|
4188
|
+
elif key.startswith("token2wav_"):
|
|
4189
|
+
token2wav_kwargs[key[len("token2wav_") :]] = value
|
|
4190
|
+
# Process special input values
|
|
4191
|
+
elif key == "feature_attention_mask":
|
|
4192
|
+
thinker_kwargs[key] = value
|
|
4193
|
+
talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1)
|
|
4194
|
+
elif key == "input_features" or key == "attention_mask":
|
|
4195
|
+
thinker_kwargs[key] = value
|
|
4196
|
+
# Put other keys into shared kwargs
|
|
4197
|
+
else:
|
|
4198
|
+
shared_kwargs[key] = value
|
|
4199
|
+
# Merge kwargs
|
|
4200
|
+
for key, value in shared_kwargs.items():
|
|
4201
|
+
if key not in thinker_kwargs:
|
|
4202
|
+
thinker_kwargs[key] = value
|
|
4203
|
+
if key not in talker_kwargs:
|
|
4204
|
+
talker_kwargs[key] = value
|
|
4205
|
+
if key not in token2wav_kwargs:
|
|
4206
|
+
token2wav_kwargs[key] = value
|
|
4207
|
+
speaker_params = self.speaker_map[spk]
|
|
4208
|
+
|
|
4209
|
+
# 1. Generate from thinker module
|
|
4210
|
+
thinker_result = self.thinker.generate(
|
|
4211
|
+
input_ids=input_ids,
|
|
4212
|
+
return_dict_in_generate=True,
|
|
4213
|
+
output_hidden_states=True,
|
|
4214
|
+
**thinker_kwargs,
|
|
4215
|
+
)
|
|
4216
|
+
if not (return_audio and self.has_talker):
|
|
4217
|
+
return thinker_result.sequences
|
|
4218
|
+
|
|
4219
|
+
# 2. Generate speech tokens from talker module
|
|
4220
|
+
thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device)
|
|
4221
|
+
thinker_token_embeds = [x[0].to(self.talker.device) for x in thinker_result.hidden_states]
|
|
4222
|
+
thinker_hidden_states = [x[1][-1].to(self.talker.device) for x in thinker_result.hidden_states]
|
|
4223
|
+
|
|
4224
|
+
talker_text_bos_token = speaker_params["bos_token"]
|
|
4225
|
+
talker_input_text_ids = torch.cat(
|
|
4226
|
+
[
|
|
4227
|
+
input_ids.to(self.talker.device),
|
|
4228
|
+
torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.talker.device),
|
|
4229
|
+
thinker_generate_ids[:, :1],
|
|
4230
|
+
],
|
|
4231
|
+
dim=-1,
|
|
4232
|
+
)
|
|
4233
|
+
|
|
4234
|
+
talker_input_ids = torch.cat(
|
|
4235
|
+
[
|
|
4236
|
+
torch.full_like(input_ids, fill_value=self.talker.codec_mask_token, device=self.talker.device),
|
|
4237
|
+
torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=self.talker.device),
|
|
4238
|
+
torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=self.talker.device),
|
|
4239
|
+
],
|
|
4240
|
+
dim=1,
|
|
4241
|
+
)
|
|
4242
|
+
|
|
4243
|
+
thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
|
|
4244
|
+
talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0]
|
|
4245
|
+
talker_inputs_embeds = torch.cat(
|
|
4246
|
+
[
|
|
4247
|
+
talker_inputs_embeds,
|
|
4248
|
+
self.thinker.get_input_embeddings()(
|
|
4249
|
+
torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device)
|
|
4250
|
+
).to(self.talker.device),
|
|
4251
|
+
thinker_reply_part[:, :1, :],
|
|
4252
|
+
],
|
|
4253
|
+
dim=1,
|
|
4254
|
+
)
|
|
4255
|
+
|
|
4256
|
+
thinker_reply_part = torch.cat(
|
|
4257
|
+
[
|
|
4258
|
+
thinker_reply_part[:, 1:, :],
|
|
4259
|
+
self.thinker.get_input_embeddings()(
|
|
4260
|
+
torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device)
|
|
4261
|
+
).to(self.talker.device),
|
|
4262
|
+
self.thinker.get_input_embeddings()(
|
|
4263
|
+
torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=self.thinker.device)
|
|
4264
|
+
).to(self.talker.device),
|
|
4265
|
+
],
|
|
4266
|
+
dim=1,
|
|
4267
|
+
)
|
|
4268
|
+
|
|
4269
|
+
talker_attention_mask = torch.cat(
|
|
4270
|
+
[kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1
|
|
4271
|
+
).to(self.talker.device)
|
|
4272
|
+
|
|
4273
|
+
talker_result = self.talker.generate(
|
|
4274
|
+
input_ids=talker_input_ids,
|
|
4275
|
+
input_text_ids=talker_input_text_ids,
|
|
4276
|
+
thinker_reply_part=thinker_reply_part,
|
|
4277
|
+
inputs_embeds=talker_inputs_embeds,
|
|
4278
|
+
attention_mask=talker_attention_mask,
|
|
4279
|
+
suppress_tokens=[self.talker.codec_bos_token],
|
|
4280
|
+
**{k: (v.to(self.talker.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()},
|
|
4281
|
+
)
|
|
4282
|
+
talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1]
|
|
4283
|
+
|
|
4284
|
+
# 3. Generate wavs from code
|
|
4285
|
+
if self.token2wav.dtype != torch.float:
|
|
4286
|
+
self.token2wav.float()
|
|
4287
|
+
wav = self.token2wav(
|
|
4288
|
+
talker_generate_codes.to(self.token2wav.device),
|
|
4289
|
+
cond=speaker_params["cond"].to(self.token2wav.device).float(),
|
|
4290
|
+
ref_mel=speaker_params["ref_mel"].to(self.token2wav.device).float(),
|
|
4291
|
+
**token2wav_kwargs,
|
|
4292
|
+
)
|
|
4293
|
+
|
|
4294
|
+
return thinker_result.sequences, wav.float()
|
|
4295
|
+
|
|
4296
|
+
|
|
4297
|
+
__all__ = [
|
|
4298
|
+
"Qwen2_5OmniModel",
|
|
4299
|
+
"Qwen2_5OmniThinkerModel",
|
|
4300
|
+
"Qwen2_5OmniThinkerForConditionalGeneration",
|
|
4301
|
+
"Qwen2_5OmniTalkerModel",
|
|
4302
|
+
"Qwen2_5OmniTalkerForConditionalGeneration",
|
|
4303
|
+
"Qwen2_5OmniToken2WavDiTModel",
|
|
4304
|
+
"Qwen2_5OmniToken2WavBigVGANModel",
|
|
4305
|
+
"Qwen2_5OmniToken2WavModel",
|
|
4306
|
+
"Qwen2_5OmniPreTrainedModel",
|
|
4307
|
+
"Qwen2_5OmniPreTrainedModelForConditionalGeneration",
|
|
4308
|
+
]
|