transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
import math
|
|
18
18
|
from dataclasses import dataclass
|
|
19
|
-
from typing import Optional
|
|
19
|
+
from typing import Optional, Union
|
|
20
20
|
|
|
21
21
|
import torch
|
|
22
22
|
from torch import nn
|
|
@@ -462,7 +462,7 @@ class TvpEncoder(nn.Module):
|
|
|
462
462
|
output_attentions: Optional[bool] = None,
|
|
463
463
|
output_hidden_states: Optional[bool] = None,
|
|
464
464
|
return_dict: Optional[bool] = None,
|
|
465
|
-
):
|
|
465
|
+
) -> Union[tuple, BaseModelOutput]:
|
|
466
466
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
|
467
467
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
468
468
|
output_hidden_states = (
|
|
@@ -722,7 +722,7 @@ class TvpModel(TvpPreTrainedModel):
|
|
|
722
722
|
return_dict: Optional[bool] = None,
|
|
723
723
|
interpolate_pos_encoding: bool = False,
|
|
724
724
|
**kwargs,
|
|
725
|
-
):
|
|
725
|
+
) -> Union[tuple, BaseModelOutputWithPooling]:
|
|
726
726
|
r"""
|
|
727
727
|
Examples:
|
|
728
728
|
```python
|
|
@@ -824,7 +824,7 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
|
|
|
824
824
|
return_dict: Optional[bool] = None,
|
|
825
825
|
interpolate_pos_encoding: bool = False,
|
|
826
826
|
**kwargs,
|
|
827
|
-
):
|
|
827
|
+
) -> Union[tuple, TvpVideoGroundingOutput]:
|
|
828
828
|
r"""
|
|
829
829
|
labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
|
|
830
830
|
The labels contains duration, start time, and end time of the video corresponding to the text.
|
|
@@ -1106,7 +1106,7 @@ class UdopStack(UdopPreTrainedModel):
|
|
|
1106
1106
|
return_dict=None,
|
|
1107
1107
|
cache_position=None,
|
|
1108
1108
|
**kwargs,
|
|
1109
|
-
):
|
|
1109
|
+
) -> Union[tuple, BaseModelOutputWithAttentionMask]:
|
|
1110
1110
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
1111
1111
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
1112
1112
|
output_hidden_states = (
|
|
@@ -1436,12 +1436,10 @@ class UdopModel(UdopPreTrainedModel):
|
|
|
1436
1436
|
encoder_config = deepcopy(config)
|
|
1437
1437
|
encoder_config.is_decoder = False
|
|
1438
1438
|
encoder_config.use_cache = False
|
|
1439
|
-
encoder_config.tie_word_embeddings = True
|
|
1440
1439
|
self.encoder = UdopStack(encoder_config)
|
|
1441
1440
|
|
|
1442
1441
|
decoder_config = deepcopy(config)
|
|
1443
1442
|
decoder_config.is_decoder = True
|
|
1444
|
-
decoder_config.tie_word_embeddings = True
|
|
1445
1443
|
decoder_config.num_layers = config.num_decoder_layers
|
|
1446
1444
|
self.decoder = UdopStack(decoder_config)
|
|
1447
1445
|
|
|
@@ -1476,7 +1474,7 @@ class UdopModel(UdopPreTrainedModel):
|
|
|
1476
1474
|
return_dict: Optional[bool] = None,
|
|
1477
1475
|
cache_position: Optional[torch.LongTensor] = None,
|
|
1478
1476
|
**kwargs,
|
|
1479
|
-
) -> tuple
|
|
1477
|
+
) -> Union[tuple, Seq2SeqModelOutput]:
|
|
1480
1478
|
r"""
|
|
1481
1479
|
bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
|
|
1482
1480
|
Bounding boxes of each input sequence tokens. Selected in the range `[0,
|
|
@@ -1611,12 +1609,10 @@ class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
|
|
|
1611
1609
|
encoder_config = deepcopy(config)
|
|
1612
1610
|
encoder_config.is_decoder = False
|
|
1613
1611
|
encoder_config.use_cache = False
|
|
1614
|
-
encoder_config.tie_encoder_decoder = False
|
|
1615
1612
|
self.encoder = UdopStack(encoder_config)
|
|
1616
1613
|
|
|
1617
1614
|
decoder_config = deepcopy(config)
|
|
1618
1615
|
decoder_config.is_decoder = True
|
|
1619
|
-
decoder_config.tie_encoder_decoder = False
|
|
1620
1616
|
decoder_config.num_layers = config.num_decoder_layers
|
|
1621
1617
|
self.decoder = UdopStack(decoder_config)
|
|
1622
1618
|
|
|
@@ -1655,7 +1651,7 @@ class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
|
|
|
1655
1651
|
labels: Optional[Tensor] = None,
|
|
1656
1652
|
cache_position: Optional[torch.LongTensor] = None,
|
|
1657
1653
|
**kwargs,
|
|
1658
|
-
) -> tuple
|
|
1654
|
+
) -> Union[tuple, Seq2SeqLMOutput]:
|
|
1659
1655
|
r"""
|
|
1660
1656
|
bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
|
|
1661
1657
|
Bounding boxes of each input sequence tokens. Selected in the range `[0,
|
|
@@ -94,7 +94,6 @@ class UMT5Config(PreTrainedConfig):
|
|
|
94
94
|
is_encoder_decoder=True,
|
|
95
95
|
use_cache=True,
|
|
96
96
|
tokenizer_class="T5Tokenizer",
|
|
97
|
-
tie_word_embeddings=True,
|
|
98
97
|
pad_token_id=0,
|
|
99
98
|
eos_token_id=1,
|
|
100
99
|
decoder_start_token_id=0,
|
|
@@ -133,10 +132,11 @@ class UMT5Config(PreTrainedConfig):
|
|
|
133
132
|
if feed_forward_proj == "gated-gelu":
|
|
134
133
|
self.dense_act_fn = "gelu_new"
|
|
135
134
|
|
|
135
|
+
# Force because official weights have False serialized, but we have to tie always
|
|
136
|
+
kwargs["tie_word_embeddings"] = True
|
|
136
137
|
super().__init__(
|
|
137
138
|
is_encoder_decoder=is_encoder_decoder,
|
|
138
139
|
tokenizer_class=tokenizer_class,
|
|
139
|
-
tie_word_embeddings=tie_word_embeddings,
|
|
140
140
|
pad_token_id=pad_token_id,
|
|
141
141
|
eos_token_id=eos_token_id,
|
|
142
142
|
decoder_start_token_id=decoder_start_token_id,
|
|
@@ -929,12 +929,10 @@ class UMT5Model(UMT5PreTrainedModel):
|
|
|
929
929
|
encoder_config = copy.deepcopy(config)
|
|
930
930
|
encoder_config.is_decoder = False
|
|
931
931
|
encoder_config.use_cache = False
|
|
932
|
-
encoder_config.tie_encoder_decoder = False
|
|
933
932
|
self.encoder = UMT5Stack(encoder_config)
|
|
934
933
|
|
|
935
934
|
decoder_config = copy.deepcopy(config)
|
|
936
935
|
decoder_config.is_decoder = True
|
|
937
|
-
decoder_config.tie_encoder_decoder = False
|
|
938
936
|
decoder_config.num_layers = config.num_decoder_layers
|
|
939
937
|
self.decoder = UMT5Stack(decoder_config)
|
|
940
938
|
|
|
@@ -1108,12 +1106,10 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin):
|
|
|
1108
1106
|
encoder_config = copy.deepcopy(config)
|
|
1109
1107
|
encoder_config.is_decoder = False
|
|
1110
1108
|
encoder_config.use_cache = False
|
|
1111
|
-
encoder_config.tie_encoder_decoder = False
|
|
1112
1109
|
self.encoder = UMT5Stack(encoder_config)
|
|
1113
1110
|
|
|
1114
1111
|
decoder_config = copy.deepcopy(config)
|
|
1115
1112
|
decoder_config.is_decoder = True
|
|
1116
|
-
decoder_config.tie_encoder_decoder = False
|
|
1117
1113
|
decoder_config.num_layers = config.num_decoder_layers
|
|
1118
1114
|
self.decoder = UMT5Stack(decoder_config)
|
|
1119
1115
|
|
|
@@ -1614,12 +1610,10 @@ class UMT5ForQuestionAnswering(UMT5PreTrainedModel):
|
|
|
1614
1610
|
encoder_config = copy.deepcopy(config)
|
|
1615
1611
|
encoder_config.is_decoder = False
|
|
1616
1612
|
encoder_config.use_cache = False
|
|
1617
|
-
encoder_config.tie_encoder_decoder = False
|
|
1618
1613
|
self.encoder = UMT5Stack(encoder_config)
|
|
1619
1614
|
|
|
1620
1615
|
decoder_config = copy.deepcopy(config)
|
|
1621
1616
|
decoder_config.is_decoder = True
|
|
1622
|
-
decoder_config.tie_encoder_decoder = False
|
|
1623
1617
|
decoder_config.num_layers = config.num_decoder_layers
|
|
1624
1618
|
self.decoder = UMT5Stack(decoder_config)
|
|
1625
1619
|
|
|
@@ -297,7 +297,7 @@ class VaultGemmaRotaryEmbedding(nn.Module):
|
|
|
297
297
|
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
|
298
298
|
|
|
299
299
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
300
|
-
self.original_inv_freq =
|
|
300
|
+
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
|
301
301
|
|
|
302
302
|
@staticmethod
|
|
303
303
|
def compute_default_rope_parameters(
|
|
@@ -154,8 +154,9 @@ class VideoLlama3ImageProcessor(BaseImageProcessor):
|
|
|
154
154
|
**kwargs,
|
|
155
155
|
) -> None:
|
|
156
156
|
super().__init__(**kwargs)
|
|
157
|
-
if size is not None
|
|
158
|
-
|
|
157
|
+
if size is not None:
|
|
158
|
+
if "shortest_edge" not in size or "longest_edge" not in size:
|
|
159
|
+
raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
|
|
159
160
|
else:
|
|
160
161
|
size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
|
|
161
162
|
# backward compatibility: override size with min_pixels and max_pixels if they are provided
|
|
@@ -25,6 +25,7 @@ import torch
|
|
|
25
25
|
import torch.nn as nn
|
|
26
26
|
from torch.nn import LayerNorm
|
|
27
27
|
|
|
28
|
+
from ... import initialization as init
|
|
28
29
|
from ...activations import ACT2FN
|
|
29
30
|
from ...cache_utils import Cache
|
|
30
31
|
from ...generation import GenerationMixin
|
|
@@ -43,6 +44,8 @@ class VideoLlama3VisionRotaryEmbedding(nn.Module):
|
|
|
43
44
|
|
|
44
45
|
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
|
45
46
|
super().__init__()
|
|
47
|
+
self.dim = dim
|
|
48
|
+
self.theta = theta
|
|
46
49
|
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
|
47
50
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
48
51
|
|
|
@@ -380,6 +383,12 @@ class VideoLlama3PreTrainedModel(PreTrainedModel):
|
|
|
380
383
|
_can_compile_fullgraph = True
|
|
381
384
|
_supports_attention_backend = True
|
|
382
385
|
|
|
386
|
+
def _init_weights(self, module):
|
|
387
|
+
super()._init_weights(module)
|
|
388
|
+
if isinstance(module, VideoLlama3VisionRotaryEmbedding):
|
|
389
|
+
inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
|
|
390
|
+
init.copy_(module.inv_freq, inv_freq)
|
|
391
|
+
|
|
383
392
|
|
|
384
393
|
class VideoLlama3VisionModel(VideoLlama3PreTrainedModel):
|
|
385
394
|
config: VideoLlama3VisionConfig
|
|
@@ -855,6 +864,7 @@ class VideoLlama3ForConditionalGeneration(VideoLlama3PreTrainedModel, Generation
|
|
|
855
864
|
video_grid_thw: Optional[torch.LongTensor] = None,
|
|
856
865
|
video_merge_sizes: Optional[torch.LongTensor] = None,
|
|
857
866
|
video_compression_mask: Optional[torch.BoolTensor] = None,
|
|
867
|
+
is_first_iteration: Optional[bool] = False,
|
|
858
868
|
**kwargs,
|
|
859
869
|
):
|
|
860
870
|
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
|
@@ -874,10 +884,11 @@ class VideoLlama3ForConditionalGeneration(VideoLlama3PreTrainedModel, Generation
|
|
|
874
884
|
video_merge_sizes=video_merge_sizes,
|
|
875
885
|
video_compression_mask=video_compression_mask,
|
|
876
886
|
use_cache=use_cache,
|
|
887
|
+
is_first_iteration=is_first_iteration,
|
|
877
888
|
**kwargs,
|
|
878
889
|
)
|
|
879
890
|
|
|
880
|
-
if
|
|
891
|
+
if not is_first_iteration and use_cache:
|
|
881
892
|
model_inputs["pixel_values"] = None
|
|
882
893
|
model_inputs["pixel_values_videos"] = None
|
|
883
894
|
|
|
@@ -21,6 +21,7 @@ import torch.nn as nn
|
|
|
21
21
|
import torch.nn.functional as F
|
|
22
22
|
from torch.nn import LayerNorm
|
|
23
23
|
|
|
24
|
+
from ... import initialization as init
|
|
24
25
|
from ...cache_utils import Cache
|
|
25
26
|
from ...configuration_utils import PreTrainedConfig
|
|
26
27
|
from ...feature_extraction_utils import BatchFeature
|
|
@@ -433,6 +434,12 @@ class VideoLlama3PreTrainedModel(Qwen2VLPreTrainedModel):
|
|
|
433
434
|
config: VideoLlama3Config
|
|
434
435
|
_no_split_modules = ["VideoLlama3VisionEncoderLayer"]
|
|
435
436
|
|
|
437
|
+
def _init_weights(self, module):
|
|
438
|
+
PreTrainedModel._init_weights(self, module)
|
|
439
|
+
if isinstance(module, VideoLlama3VisionRotaryEmbedding):
|
|
440
|
+
inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
|
|
441
|
+
init.copy_(module.inv_freq, inv_freq)
|
|
442
|
+
|
|
436
443
|
|
|
437
444
|
class VideoLlama3VisionModel(VideoLlama3PreTrainedModel):
|
|
438
445
|
config: VideoLlama3VisionConfig
|
|
@@ -842,6 +849,7 @@ class VideoLlama3ForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
|
|
842
849
|
video_grid_thw: Optional[torch.LongTensor] = None,
|
|
843
850
|
video_merge_sizes: Optional[torch.LongTensor] = None,
|
|
844
851
|
video_compression_mask: Optional[torch.BoolTensor] = None,
|
|
852
|
+
is_first_iteration: Optional[bool] = False,
|
|
845
853
|
**kwargs,
|
|
846
854
|
):
|
|
847
855
|
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
|
@@ -861,10 +869,11 @@ class VideoLlama3ForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
|
|
861
869
|
video_merge_sizes=video_merge_sizes,
|
|
862
870
|
video_compression_mask=video_compression_mask,
|
|
863
871
|
use_cache=use_cache,
|
|
872
|
+
is_first_iteration=is_first_iteration,
|
|
864
873
|
**kwargs,
|
|
865
874
|
)
|
|
866
875
|
|
|
867
|
-
if
|
|
876
|
+
if not is_first_iteration and use_cache:
|
|
868
877
|
model_inputs["pixel_values"] = None
|
|
869
878
|
model_inputs["pixel_values_videos"] = None
|
|
870
879
|
|
|
@@ -599,6 +599,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
|
|
|
599
599
|
attention_mask=None,
|
|
600
600
|
cache_position=None,
|
|
601
601
|
logits_to_keep=None,
|
|
602
|
+
is_first_iteration=False,
|
|
602
603
|
**kwargs,
|
|
603
604
|
):
|
|
604
605
|
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
|
@@ -610,12 +611,15 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
|
|
|
610
611
|
attention_mask=attention_mask,
|
|
611
612
|
cache_position=cache_position,
|
|
612
613
|
logits_to_keep=logits_to_keep,
|
|
614
|
+
is_first_iteration=is_first_iteration,
|
|
613
615
|
**kwargs,
|
|
614
616
|
)
|
|
615
617
|
|
|
616
|
-
if
|
|
617
|
-
#
|
|
618
|
-
#
|
|
618
|
+
if is_first_iteration or not kwargs.get("use_cache", True):
|
|
619
|
+
# Pixel values are used only in the first iteration if available
|
|
620
|
+
# In subsquent iterations, they are already merged with text and cached
|
|
621
|
+
# NOTE: first iteration doesn't have to be prefill, it can be the first
|
|
622
|
+
# iteration with a question and cached system prompt (continue generate from cache)
|
|
619
623
|
model_inputs["pixel_values_images"] = pixel_values_images
|
|
620
624
|
model_inputs["pixel_values_videos"] = pixel_values_videos
|
|
621
625
|
|
|
@@ -115,7 +115,7 @@ class ViltConfig(PreTrainedConfig):
|
|
|
115
115
|
num_channels=3,
|
|
116
116
|
qkv_bias=True,
|
|
117
117
|
max_image_length=-1,
|
|
118
|
-
tie_word_embeddings=
|
|
118
|
+
tie_word_embeddings=True,
|
|
119
119
|
num_images=-1,
|
|
120
120
|
**kwargs,
|
|
121
121
|
):
|
|
@@ -142,7 +142,7 @@ class ViltConfig(PreTrainedConfig):
|
|
|
142
142
|
self.qkv_bias = qkv_bias
|
|
143
143
|
self.max_image_length = max_image_length
|
|
144
144
|
self.num_images = num_images
|
|
145
|
-
self.
|
|
145
|
+
self.tie_word_embeddings = True # force it
|
|
146
146
|
|
|
147
147
|
|
|
148
148
|
__all__ = ["ViltConfig"]
|
|
@@ -23,6 +23,7 @@ import torch
|
|
|
23
23
|
from torch import nn
|
|
24
24
|
from torch.nn import CrossEntropyLoss
|
|
25
25
|
|
|
26
|
+
from ... import initialization as init
|
|
26
27
|
from ...activations import ACT2FN
|
|
27
28
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
28
29
|
from ...modeling_outputs import (
|
|
@@ -516,6 +517,12 @@ class ViltPreTrainedModel(PreTrainedModel):
|
|
|
516
517
|
supports_gradient_checkpointing = True
|
|
517
518
|
_no_split_modules = ["ViltEmbeddings", "ViltSelfAttention"]
|
|
518
519
|
|
|
520
|
+
def _init_weights(self, module):
|
|
521
|
+
super()._init_weights(module)
|
|
522
|
+
if isinstance(module, TextEmbeddings):
|
|
523
|
+
init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
|
|
524
|
+
init.zeros_(module.token_type_ids)
|
|
525
|
+
|
|
519
526
|
|
|
520
527
|
@auto_docstring
|
|
521
528
|
class ViltModel(ViltPreTrainedModel):
|
|
@@ -415,6 +415,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
|
|
|
415
415
|
attention_mask=None,
|
|
416
416
|
cache_position=None,
|
|
417
417
|
logits_to_keep=None,
|
|
418
|
+
is_first_iteration=False,
|
|
418
419
|
**kwargs,
|
|
419
420
|
):
|
|
420
421
|
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
|
@@ -426,12 +427,15 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        if
-        #
-        #
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration if available
+            # In subsequent iterations, they are already merged with text and cached
+            # NOTE: first iteration doesn't have to be prefill, it can be the first
+            # iteration with a question and cached system prompt (continue generate from cache)
             model_inputs["pixel_values"] = pixel_values

         return model_inputs
@@ -473,6 +473,8 @@ class VisualBertPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, VisualBertLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, VisualBertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 @dataclass
@@ -36,7 +36,7 @@ class VitMatteConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.

     Args:
-        backbone_config (`PreTrainedConfig
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `VitDetConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -152,7 +152,6 @@ class VitMatteImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

@@ -65,6 +65,10 @@ class VitMattePreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)


 class VitMatteBasicConv3x3(nn.Module):
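The added branch in VitMattePreTrainedModel resets running statistics whenever the normalization module carries them. The values it writes (zero mean, unit variance, zero batch counter) are the same state `nn.BatchNorm1d.reset_running_stats()` restores; a quick sanity check on a plain BatchNorm1d:

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(4)
bn(torch.randn(8, 4))  # update the running statistics so they differ from the defaults

# Equivalent of the new init branch, written with plain in-place ops.
with torch.no_grad():
    bn.running_mean.zero_()
    bn.running_var.fill_(1.0)
    bn.num_batches_tracked.zero_()

assert torch.equal(bn.running_mean, torch.zeros(4))
assert torch.equal(bn.running_var, torch.ones(4))
assert bn.num_batches_tracked.item() == 0
```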
@@ -36,7 +36,7 @@ class VitPoseConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.

     Args:
-        backbone_config (`PreTrainedConfig
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `VitPoseBackboneConfig()`):
             The configuration of the backbone model. Currently, only `backbone_config` with `vitpose_backbone` as `model_type` is supported.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -156,7 +156,6 @@ class VitPoseImageProcessorFast(BaseImageProcessorFast):
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)

         # Stack into batch tensor
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

@@ -505,11 +505,11 @@ class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
         # Overwritten -- we should not pass input_features when we are in cached decoding stage

         input_features = kwargs.pop("input_features", None)
-
+        is_first_iteration = kwargs.get("is_first_iteration", False)

         model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)

-        if
+        if is_first_iteration or not kwargs.get("use_cache", True):
             # input_features should only be passed when we are not in cached decoding stage
             model_inputs["input_features"] = input_features

@@ -267,11 +267,11 @@ class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
         # Overwritten -- we should not pass input_features when we are in cached decoding stage

         input_features = kwargs.pop("input_features", None)
-
+        is_first_iteration = kwargs.get("is_first_iteration", False)

         model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)

-        if
+        if is_first_iteration or not kwargs.get("use_cache", True):
             # input_features should only be passed when we are not in cached decoding stage
             model_inputs["input_features"] = input_features

@@ -74,18 +74,17 @@ class Wav2Vec2BertRelPositionalEmbedding(nn.Module):
         super().__init__()
         self.max_len = config.max_source_positions
         self.d_model = config.hidden_size
-        self.pe =
-        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+        self.register_buffer("pe", self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)), persistent=False)

-    def extend_pe(self, x):
+    def extend_pe(self, x, pe=None):
         # Reset the positional encodings
-        if
+        if pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if
-            if
-
-            return
+            if pe.size(1) >= x.size(1) * 2 - 1:
+                if pe.dtype != x.dtype or pe.device != x.device:
+                    pe = pe.to(dtype=x.dtype, device=x.device)
+                return pe
         # Suppose `i` is the position of query vector and `j` is the
         # position of key vector. We use positive relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -106,10 +105,10 @@ class Wav2Vec2BertRelPositionalEmbedding(nn.Module):
         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         pe = torch.cat([pe_positive, pe_negative], dim=1)
-
+        return pe.to(device=x.device, dtype=x.dtype)

     def forward(self, hidden_states: torch.Tensor):
-        self.extend_pe(hidden_states)
+        self.pe = self.extend_pe(hidden_states, self.pe)
         start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
         end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
         relative_position_embeddings = self.pe[:, start_idx:end_idx]
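The two hunks above turn `pe` into a non-persistent buffer built once in `__init__`, and make `extend_pe` return the (possibly regenerated) table instead of assigning to `self.pe` as a side effect; `forward` only swaps the buffer when the cached table is too short or sits on the wrong dtype/device. Below is a self-contained sketch of that buffer-and-regenerate pattern, simplified to ordinary sinusoidal encodings rather than the model's positive/negative relative table.

```python
import torch
from torch import nn

class ToyRelPositionalEmbedding(nn.Module):
    def __init__(self, d_model=8, max_len=16):
        super().__init__()
        self.d_model = d_model
        # Built once, excluded from the state dict, regenerated on demand.
        self.register_buffer("pe", self.extend_pe(torch.zeros(1, max_len)), persistent=False)

    def extend_pe(self, x, pe=None):
        if pe is not None and pe.size(1) >= x.size(1):
            # Cached table is long enough; only align dtype/device.
            return pe.to(dtype=x.dtype, device=x.device)
        position = torch.arange(x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / self.d_model)
        )
        pe = torch.zeros(x.size(1), self.d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0).to(device=x.device, dtype=x.dtype)

    def forward(self, hidden_states):
        # Reassign the buffer only when the cached table cannot cover the input.
        self.pe = self.extend_pe(hidden_states, self.pe)
        return self.pe[:, : hidden_states.size(1)]

emb = ToyRelPositionalEmbedding()
print(emb(torch.randn(1, 32, 8)).shape)  # torch.Size([1, 32, 8])
```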
@@ -749,6 +748,13 @@ class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
             init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1))
         elif isinstance(module, AMSoftmaxLoss):  # noqa: F821
             init.normal_(module.weight)
+        elif isinstance(module, Wav2Vec2BertRotaryPositionalEmbedding):
+            dim = self.config.hidden_size // self.config.num_attention_heads
+            base = self.config.rotary_embedding_base
+            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+            init.copy_(module.inv_freq, inv_freq)
+        elif isinstance(module, Wav2Vec2BertRelPositionalEmbedding):
+            init.copy_(module.pe, module.extend_pe(torch.tensor(0.0).expand(1, module.max_len)))

     # Ignore copy
     def _get_feat_extract_output_lengths(
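Both copies of `_init_weights` now rebuild the rotary inverse-frequency buffer from the config (per-head dimension and `rotary_embedding_base`) rather than trusting whatever the checkpoint holds. The formula is the standard RoPE spacing; a short sketch with illustrative default sizes:

```python
import torch

def rotary_inv_freq(hidden_size=768, num_attention_heads=12, base=10000.0):
    # One frequency per pair of channels in a head, geometrically spaced by `base`.
    dim = hidden_size // num_attention_heads
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

inv_freq = rotary_inv_freq()
print(inv_freq.shape)              # torch.Size([32]) -> dim // 2 frequencies
print(inv_freq[0], inv_freq[-1])   # from 1.0 down to base ** (-(dim - 2) / dim)
```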
@@ -621,6 +621,13 @@ class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
             init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1))
         elif isinstance(module, AMSoftmaxLoss):  # noqa: F821
             init.normal_(module.weight)
+        elif isinstance(module, Wav2Vec2BertRotaryPositionalEmbedding):
+            dim = self.config.hidden_size // self.config.num_attention_heads
+            base = self.config.rotary_embedding_base
+            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+            init.copy_(module.inv_freq, inv_freq)
+        elif isinstance(module, Wav2Vec2BertRelPositionalEmbedding):
+            init.copy_(module.pe, module.extend_pe(torch.tensor(0.0).expand(1, module.max_len)))

     # Ignore copy
     def _get_feat_extract_output_lengths(
@@ -164,18 +164,17 @@ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
         super().__init__()
         self.max_len = config.max_source_positions
         self.d_model = config.hidden_size
-        self.pe =
-        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+        self.register_buffer("pe", self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)), persistent=False)

-    def extend_pe(self, x):
+    def extend_pe(self, x, pe=None):
         # Reset the positional encodings
-        if
+        if pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if
-            if
-
-            return
+            if pe.size(1) >= x.size(1) * 2 - 1:
+                if pe.dtype != x.dtype or pe.device != x.device:
+                    pe = pe.to(dtype=x.dtype, device=x.device)
+                return pe
         # Suppose `i` is the position of query vector and `j` is the
         # position of key vector. We use positive relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -196,10 +195,10 @@ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         pe = torch.cat([pe_positive, pe_negative], dim=1)
-
+        return pe.to(device=x.device, dtype=x.dtype)

     def forward(self, hidden_states: torch.Tensor):
-        self.extend_pe(hidden_states)
+        self.pe = self.extend_pe(hidden_states, self.pe)
         start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
         end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
         relative_position_embeddings = self.pe[:, start_idx:end_idx]
@@ -885,15 +884,26 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):

             if module.bias is not None:
                 init.zeros_(module.bias)
-        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.Conv1d):
             init.kaiming_normal_(module.weight)

             if module.bias is not None:
                 k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                 init.uniform_(module.bias, a=-k, b=k)
+        elif isinstance(module, Wav2Vec2ConformerRotaryPositionalEmbedding):
+            dim = self.config.hidden_size // self.config.num_attention_heads
+            base = self.config.rotary_embedding_base
+            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+            init.copy_(module.inv_freq, inv_freq)
+        elif isinstance(module, Wav2Vec2ConformerRelPositionalEmbedding):
+            init.copy_(module.pe, module.extend_pe(torch.tensor(0.0).expand(1, module.max_len)))

     def _get_feat_extract_output_lengths(
         self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
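For reference, the Conv1d bias bound kept as context in the last hunk, `k = sqrt(groups / (in_channels * kernel_size))`, works out to `1 / sqrt(fan_in)`, which matches PyTorch's default conv bias initialization, so `init.uniform_(module.bias, a=-k, b=k)` reproduces the stock reset:

```python
import math
from torch import nn

conv = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, groups=4)
k = math.sqrt(conv.groups / (conv.in_channels * conv.kernel_size[0]))
fan_in = conv.weight.shape[1] * conv.weight.shape[2]  # (in_channels // groups) * kernel_size
print(k, 1 / math.sqrt(fan_in))  # both ~0.2887
```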