transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
--- a/transformers/models/bamba/modeling_bamba.py
+++ b/transformers/models/bamba/modeling_bamba.py
@@ -35,7 +35,7 @@ from transformers.activations import ACT2FN
 from ... import initialization as init
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
+from ...integrations import lazy_load_kernel, use_kernel_forward_from_hub, use_kernelized_func
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -44,22 +44,9 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.generic import maybe_autocast
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_bamba import BambaConfig


-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    selective_state_update = None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-
 logger = logging.get_logger(__name__)


@@ -212,7 +199,7 @@ class BambaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -501,9 +488,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
     return hidden_states


-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
 class BambaMixer(nn.Module):
     """
@@ -575,6 +559,20 @@ class BambaMixer(nn.Module):

         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
+        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
@@ -1489,6 +1487,7 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
@@ -1521,7 +1520,7 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
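The BambaMixer hunk above replaces import-time probing of optional CUDA kernels with lazy resolution the first time a mixer is constructed. A minimal sketch of that pattern, assuming only what the diff shows (`lazy_load_kernel` returns a module-like object when the kernel is installed and something falsy otherwise; the real helper in transformers/integrations/hub_kernels.py may behave differently, e.g. by fetching from the kernels hub):

    import importlib

    _kernel_cache = {}

    def lazy_load_kernel(name):
        # Hypothetical stand-in: import the backing package on first request
        # and cache the module (or None when it is not installed).
        if name not in _kernel_cache:
            try:
                _kernel_cache[name] = importlib.import_module(name.replace("-", "_"))
            except ImportError:
                _kernel_cache[name] = None
        return _kernel_cache[name]

    # Mirrors the BambaMixer.__init__ hunk: resolve the symbols, then derive the flag.
    mamba_ssm = lazy_load_kernel("mamba-ssm")
    selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
    is_fast_path_available = selective_state_update is not None

The upshot is that importing the model module no longer pays for (or fails on) the optional dependencies; only instantiating a mixer does.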
--- a/transformers/models/bamba/modular_bamba.py
+++ b/transformers/models/bamba/modular_bamba.py
@@ -43,6 +43,7 @@ from transformers.models.mamba2.modeling_mamba2 import (
 )

 from ... import initialization as init
+from ...integrations import lazy_load_kernel
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import PreTrainedModel
@@ -52,24 +53,9 @@ from ...utils import (
     can_return_tuple,
     logging,
 )
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_bamba import BambaConfig


-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    selective_state_update = None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 logger = logging.get_logger(__name__)


@@ -276,6 +262,20 @@ class BambaMixer(nn.Module):

         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
+        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
@@ -1151,6 +1151,7 @@ class BambaForCausalLM(LlamaForCausalLM):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
@@ -1183,7 +1184,7 @@ class BambaForCausalLM(LlamaForCausalLM):
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
--- a/transformers/models/bark/modeling_bark.py
+++ b/transformers/models/bark/modeling_bark.py
@@ -23,6 +23,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F

+from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...generation.logits_process import (
@@ -349,6 +350,14 @@ class BarkPreTrainedModel(PreTrainedModel):

         return super().device

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, BarkSelfAttention):
+            if module.is_causal:
+                block_size = module.config.block_size
+                bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size)
+                init.copy_(module.bias, bias)
+

 # GPT2-like autoregressive model
 class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
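Bark's new `_init_weights` override rebuilds the causal attention bias as a lower-triangular boolean mask. What that buffer looks like, with an illustrative `block_size` of 4 (Bark reads the real value from its config):

    import torch

    block_size = 4
    bias = torch.tril(torch.ones((block_size, block_size), dtype=torch.bool)).view(1, 1, block_size, block_size)
    print(bias[0, 0].int())
    # tensor([[1, 0, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 0],
    #         [1, 1, 1, 1]])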
--- a/transformers/models/bart/modeling_bart.py
+++ b/transformers/models/bart/modeling_bart.py
@@ -23,6 +23,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -476,6 +477,11 @@ class BartPreTrainedModel(PreTrainedModel):

     _can_compile_fullgraph = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, BartForConditionalGeneration):
+            init.zeros_(module.final_logits_bias)
+
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -1463,6 +1469,7 @@ class BartDecoderWrapper(BartPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.decoder = BartDecoder(config)
+        self.post_init()

     def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)
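This Bart hunk is the template for a pattern that recurs in the BigBirdPegasus, Blenderbot, and BlenderbotSmall hunks below: buffers such as `final_logits_bias` are now (re)initialized centrally in `_init_weights`, and thin wrapper modules call `post_init()` so the hook actually runs for them. A schematic sketch of why the `post_init()` call matters (this is not the real PreTrainedModel machinery, just the shape of it):

    import torch
    from torch import nn

    class TinyPreTrained(nn.Module):
        # Stand-in for PreTrainedModel: post_init() walks submodules and
        # applies _init_weights, which is why wrappers must call it.
        def post_init(self):
            self.apply(self._init_weights)

        def _init_weights(self, module):
            if isinstance(module, TinyForConditionalGeneration):
                nn.init.zeros_(module.final_logits_bias)

    class TinyForConditionalGeneration(TinyPreTrained):
        def __init__(self):
            super().__init__()
            self.register_buffer("final_logits_bias", torch.randn(1, 10))
            self.post_init()  # without this call, _init_weights never runs here

    print(TinyForConditionalGeneration().final_logits_bias.abs().sum())  # tensor(0.)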
--- a/transformers/models/beit/image_processing_beit_fast.py
+++ b/transformers/models/beit/image_processing_beit_fast.py
@@ -163,7 +163,6 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-       processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
--- a/transformers/models/bert/modeling_bert.py
+++ b/transformers/models/bert/modeling_bert.py
@@ -569,6 +569,9 @@ class BertPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, BertLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, BertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 @dataclass
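The same embeddings branch is added for BertGeneration, BigBird, and BLIP below. For reference, the value it writes into `position_ids` is simply a row of absolute positions (shown here with a toy maximum length of 6; the real length comes from the model's `max_position_embeddings`):

    import torch

    position_ids = torch.arange(6).expand((1, -1))
    print(position_ids)  # tensor([[0, 1, 2, 3, 4, 5]])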
--- a/transformers/models/bert_generation/modeling_bert_generation.py
+++ b/transformers/models/bert_generation/modeling_bert_generation.py
@@ -463,6 +463,8 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, BertGenerationOnlyLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, BertGenerationEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 @auto_docstring(
--- a/transformers/models/big_bird/modeling_big_bird.py
+++ b/transformers/models/big_bird/modeling_big_bird.py
@@ -1521,6 +1521,9 @@ class BigBirdPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, BigBirdLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, BigBirdEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 @dataclass
--- a/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -23,6 +23,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -1536,6 +1537,11 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _can_compile_fullgraph = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, BigBirdPegasusForConditionalGeneration):
+            init.zeros_(module.final_logits_bias)
+
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -2582,6 +2588,7 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.decoder = BigBirdPegasusDecoder(config)
+        self.post_init()

     def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)
--- a/transformers/models/bit/modeling_bit.py
+++ b/transformers/models/bit/modeling_bit.py
@@ -84,7 +84,7 @@ class WeightStandardizedConv2d(nn.Conv2d):
     """Conv2d with Weight Standardization. Used for ViT Hybrid model.

     Paper: [Micro-Batch Training with Batch-Channel Normalization and Weight
-    Standardization](https://huggingface.co/papers/1903.
+    Standardization](https://huggingface.co/papers/1903.10520)
     """

     def __init__(
@@ -643,6 +643,10 @@ class BitPreTrainedModel(PreTrainedModel):
         elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
             init.constant_(module.weight, 1)
             init.constant_(module.bias, 0)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)


 @auto_docstring
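The `getattr(module, "running_mean", None)` guard exists because the branch covers both norm types: `nn.BatchNorm2d` tracks running statistics while `nn.GroupNorm` has no such buffers, so the reset must only happen when they exist. A quick check:

    from torch import nn

    bn = nn.BatchNorm2d(8)
    gn = nn.GroupNorm(2, 8)
    print(bn.running_mean is not None)        # True  -> stats get reset
    print(getattr(gn, "running_mean", None))  # None  -> guard skips the reset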
--- a/transformers/models/bitnet/modeling_bitnet.py
+++ b/transformers/models/bitnet/modeling_bitnet.py
@@ -287,7 +287,7 @@ class BitNetRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
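Both rotary-embedding hunks (Bamba above and BitNet here) turn `original_inv_freq` from a plain tensor attribute into a non-persistent buffer. The practical difference, demonstrated outside transformers: buffers follow the module through device and dtype moves, and `persistent=False` keeps them out of the checkpoint:

    import torch
    from torch import nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            inv_freq = torch.ones(4)
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    m = Demo().to(torch.float64)
    print(m.original_inv_freq.dtype)    # torch.float64 -- cast along with the module
    print(list(m.state_dict().keys()))  # [] -- non-persistent buffers are not saved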
--- a/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/transformers/models/blenderbot/modeling_blenderbot.py
@@ -24,6 +24,7 @@ import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -437,6 +438,11 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
     _can_compile_fullgraph = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, BlenderbotForConditionalGeneration):
+            init.zeros_(module.final_logits_bias)
+
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -1156,6 +1162,7 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.decoder = BlenderbotDecoder(config)
+        self.post_init()

     def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)
@@ -160,13 +160,6 @@ class BlenderbotTokenizer(TokenizersBackend):

         self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
         self._tokenizer.decoder = decoders.ByteLevel()
-        self._tokenizer.post_processor = processors.RobertaProcessing(
-            sep=(str(eos_token), self._vocab.get(str(eos_token), 2)),
-            cls=(str(bos_token), self._vocab.get(str(bos_token), 0)),
-            add_prefix_space=add_prefix_space,
-            trim_offsets=True,
-        )
-
         super().__init__(
             bos_token=bos_token,
             eos_token=eos_token,
@@ -178,6 +171,12 @@
             add_prefix_space=add_prefix_space,
             **kwargs,
         )
+        self._tokenizer.post_processor = processors.RobertaProcessing(
+            sep=(str(eos_token), self.eos_token_id),
+            cls=(str(bos_token), self.bos_token_id),
+            add_prefix_space=add_prefix_space,
+            trim_offsets=True,
+        )


 __all__ = ["BlenderbotTokenizer"]
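Note: moving the `RobertaProcessing` attachment after `super().__init__(...)` lets the code read the special-token ids from the fully constructed tokenizer (`self.eos_token_id`, `self.bos_token_id`) instead of probing a raw vocab dict with hard-coded fallbacks of 2 and 0, which would silently mislabel the ids for a vocab that orders them differently. A standalone sketch with the `tokenizers` library:

    from tokenizers import Tokenizer, models, pre_tokenizers, processors

    tok = Tokenizer(models.BPE())
    tok.add_special_tokens(["<s>", "</s>"])
    tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)

    # Look the ids up on the tokenizer itself rather than assuming 0 and 2.
    tok.post_processor = processors.RobertaProcessing(
        sep=("</s>", tok.token_to_id("</s>")),
        cls=("<s>", tok.token_to_id("<s>")),
        add_prefix_space=True,
        trim_offsets=True,
    )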
@@ -22,6 +22,7 @@ import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -430,6 +431,11 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
     _can_compile_fullgraph = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, BlenderbotSmallForConditionalGeneration):
+            init.zeros_(module.final_logits_bias)
+
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -1116,6 +1122,7 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
    def __init__(self, config):
         super().__init__(config)
         self.decoder = BlenderbotSmallDecoder(config)
+        self.post_init()

     def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)
@@ -430,6 +430,8 @@ class BlipPreTrainedModel(PreTrainedModel):
             std = self.config.vision_config.initializer_range
             init.trunc_normal_(module.position_embedding, mean=0.0, std=std)
             init.trunc_normal_(module.class_embedding, mean=0.0, std=std)
+        elif isinstance(module, BlipTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 class BlipEncoder(nn.Module):
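Note: `init.copy_(module.position_ids, torch.arange(...).expand((1, -1)))`, used here and in the BLIP text and BLIP-2 hunks below, rebuilds the canonical position-id buffer: a single row `[0, 1, ..., n-1]` with a leading batch dimension, where `expand` adds that dimension without copying memory. For example:

    import torch

    max_positions = 6
    source = torch.arange(max_positions).expand((1, -1))   # tensor([[0, 1, 2, 3, 4, 5]])

    position_ids = torch.zeros(1, max_positions, dtype=torch.long)
    position_ids.copy_(source)                             # shape stays (1, 6)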
@@ -21,6 +21,7 @@ import torch
 from torch import Tensor, device, nn
 from torch.nn import CrossEntropyLoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -504,6 +505,11 @@ class BlipTextPreTrainedModel(PreTrainedModel):
     base_model_prefix = "bert"
     _no_split_modules = []

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, BlipTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+

 # Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571
 class BlipTextModel(BlipTextPreTrainedModel):
@@ -740,6 +746,8 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
         self.cls = BlipTextOnlyMLMHead(config)
         self.label_smoothing = config.label_smoothing

+        self.post_init()
+
     def get_input_embeddings(self):
         return self.bert.get_input_embeddings()

@@ -428,6 +428,8 @@ class Blip2PreTrainedModel(PreTrainedModel):
             ),
         ):
             init.zeros_(module.query_tokens)
+        elif isinstance(module, Blip2TextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
@@ -714,36 +714,21 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
         inputs_embeds=None,
         cache_position=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten because of the fixed-shape attention mask creation

-
-
-
-
-
-
-
-
-
-
-            elif (
-                inputs_embeds is not None  # Exception 1
-                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
-            ):
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
-            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
-        else:
-            # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the
-            # input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in
-            # the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
-            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
+            **kwargs,
+        )

         # This part differs from other models because BLOOM needs a 2D mask to construct alibi tensor
         # The only difference is the usage of 2D instead of 4D mask, but the shape will be static
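Note: the rewrite drops the hand-rolled cache slicing and `inputs_embeds` bookkeeping and delegates to the shared `prepare_inputs_for_generation`, keeping only what is genuinely Bloom-specific (the 2D alibi mask patched in the next hunk). A toy sketch of the delegation shape, with illustrative classes rather than the real ones:

    class GenerationBase:
        def prepare_inputs_for_generation(self, input_ids, **kwargs):
            # Shared logic lives here: cache slicing, inputs_embeds handling,
            # forwarding of leftover kwargs, and so on.
            return {"input_ids": input_ids, **kwargs}

    class BloomLikeForCausalLM(GenerationBase):
        def prepare_inputs_for_generation(self, input_ids, **kwargs):
            # Delegate the generic work, then patch only the model-specific
            # entry (for Bloom, the fixed-shape 2D attention mask).
            model_inputs = super().prepare_inputs_for_generation(input_ids, **kwargs)
            return model_inputs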
@@ -753,24 +738,8 @@
             diff = target_length - seq_length

             new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
-            attention_mask = torch.cat(
-                [attention_mask, new_attn_mask],
-                dim=-1,
-            )
-
-        model_inputs.update(
-            {
-                "cache_position": cache_position,
-                "past_key_values": past_key_values,
-                "use_cache": use_cache,
-                "attention_mask": attention_mask,
-            }
-        )
-
-        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
-        for key, value in kwargs.items():
-            if key not in model_inputs:
-                model_inputs[key] = value
+            attention_mask = torch.cat([attention_mask, new_attn_mask], dim=-1)
+            model_inputs["attention_mask"] = attention_mask

         return model_inputs
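Note: the simplified tail collapses a multi-line `torch.cat(...)` plus a manual `model_inputs.update(...)` and kwargs-forwarding loop (both now covered by the `super()` call above) into one concatenation. Concretely, right-padding a 2D mask to a fixed target length looks like this:

    import torch

    attention_mask = torch.ones(2, 5, dtype=torch.long)  # (batch_size, seq_length)
    target_length = 8
    diff = target_length - attention_mask.shape[1]

    # Zero-pad on the right so the mask shape stays static across decoding steps.
    new_attn_mask = torch.zeros(2, diff, device=attention_mask.device, dtype=attention_mask.dtype)
    attention_mask = torch.cat([attention_mask, new_attn_mask], dim=-1)
    assert attention_mask.shape == (2, 8)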