transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/git/modeling_git.py

@@ -26,8 +26,9 @@ from torch import nn
 from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
-from ...
+from ...masking_utils import create_masks_for_generate
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -69,6 +70,104 @@ class GitVisionModelOutput(ModelOutput):
     attentions: Optional[tuple[torch.FloatTensor, ...]] = None


+# Copied from transformers.models.gemma3.modeling_gemma3.token_type_ids_mask_function
+def token_type_ids_mask_function(
+    token_type_ids: Optional[torch.Tensor],
+    image_group_ids: Optional[torch.Tensor],
+) -> Optional[Callable]:
+    """
+    This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
+    not start and end indices.
+    """
+    # Do not return an additional mask in this case
+    if token_type_ids is None:
+        return None
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        # If it's 1 for both query and key/value, we are in an image block
+        # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
+        # Since vmap doesn't support `if statement` we workaround it with `torch.where`
+        safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
+        safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
+
+        token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
+        token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)
+
+        token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
+        token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
+
+        image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
+        image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)
+
+        image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
+        image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
+
+        is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
+        same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
+
+        # This is bidirectional attention whenever we are dealing with image tokens
+        return is_image_block & same_image_block
+
+    return inner_mask
+
+
+# Copied from transformers.models.gemma3.modeling_gemma3.create_causal_mask_mapping
+def create_causal_mask_mapping(
+    config: PreTrainedConfig,
+    input_embeds: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    cache_position: torch.Tensor,
+    past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
+    token_type_ids: Optional[torch.Tensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    is_training: bool = False,
+    is_first_iteration: Optional[bool] = None,
+    **kwargs,
+) -> dict:
+    """
+    Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
+    for all kinds of forward passes. Gemma3 uses a bidirectional mask for images.
+
+    Uses `pixel_values` as an optional input to disambiguate edge cases.
+    """
+    if is_training and token_type_ids is None:
+        raise ValueError("`token_type_ids` is required as a model input when training")
+
+    mask_kwargs = {
+        "config": config.get_text_config(),
+        "input_embeds": input_embeds,
+        "attention_mask": attention_mask,
+        "cache_position": cache_position,
+        "past_key_values": past_key_values,
+        "position_ids": position_ids,
+    }
+    # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
+    # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
+    # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
+    is_first_iteration = (
+        is_first_iteration
+        if is_first_iteration is not None
+        else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
+    )
+    if token_type_ids is not None and is_first_iteration:
+        # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
+        # undo the causal masking)
+
+        # First find where a new image block starts: 1 if image and previous not image
+        # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
+        is_image = (token_type_ids == 1).to(cache_position.device)
+        is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+        new_image_start = is_image & ~is_previous_image
+        image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+        image_group_ids = torch.where(is_image, image_group_ids, -1)
+        mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+            token_type_ids.to(cache_position.device), image_group_ids
+        )
+
+    return create_masks_for_generate(**mask_kwargs)
+
+
 class GitEmbeddings(nn.Module):
     """Construct the embeddings from word and position embeddings."""

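The two helpers above are copied from Gemma3: image tokens (`token_type_ids == 1`) attend to each other bidirectionally as long as they belong to the same image block, while text stays causal. A minimal, self-contained sketch of the grouping step with toy values (not part of the package):

```python
import torch

# token_type_ids marks image patches with 1 and text tokens with 0.
token_type_ids = torch.tensor([[1, 1, 1, 0, 0, 1, 1, 0]])

is_image = token_type_ids == 1
is_previous_image = torch.nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
new_image_start = is_image & ~is_previous_image  # 1 at the first token of each image block
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, -1)
print(image_group_ids)  # tensor([[ 0,  0,  0, -1, -1,  1,  1, -1]])

# token_type_ids_mask_function then ORs "same image group" query/key pairs into the
# causal mask, which is what makes attention bidirectional inside each image block.
```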
@@ -148,17 +247,15 @@ class GitSelfAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Cache] = None,
-
-        pixel_values_present: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor]:
-        batch_size
+        batch_size = hidden_states.shape[0]
         query_layer = (
             self.query(hidden_states)
             .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
             .transpose(1, 2)
         )

-        cutoff = self.image_patch_tokens if pixel_values_present else 0
         key_layer = (
             self.key(hidden_states)
             .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
@@ -170,12 +267,9 @@ class GitSelfAttention(nn.Module):
             .transpose(1, 2)
         )
         if past_key_values is not None:
-
-
-                key_layer[:, :, cutoff:, :], value_layer[:, :, cutoff:, :], self.layer_idx
+            key_layer, value_layer = past_key_values.update(
+                key_layer, value_layer, self.layer_idx, cache_kwargs={"cache_position": cache_position}
             )
-            key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2)
-            value_layer = torch.cat([value_layer[:, :, :cutoff, :], value_layer_past], dim=2)

         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
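GitSelfAttention drops the manual `cutoff` bookkeeping and past key/value concatenation in favour of the standard `Cache.update` API keyed by `cache_position`. A rough sketch of that contract, assuming the no-argument `DynamicCache` constructor:

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
layer_idx = 0

for step in range(3):
    # one new token per step: [batch, num_heads, new_tokens, head_dim]
    new_key = torch.randn(1, 4, 1, 8)
    new_value = torch.randn(1, 4, 1, 8)
    cache_position = torch.tensor([step])
    key, value = cache.update(new_key, new_value, layer_idx, cache_kwargs={"cache_position": cache_position})
    print(step, key.shape[-2])  # cached length grows: 1, 2, 3
```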
@@ -232,15 +326,14 @@ class GitAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = False,
-        pixel_values_present: Optional[bool] = False,
     ) -> tuple[torch.Tensor]:
         attn_output, self_attn_weights = self.self(
             hidden_states,
             attention_mask,
             past_key_values,
-
-            pixel_values_present,
+            cache_position=cache_position,
         )
         attention_output = self.output(attn_output, hidden_states)
         return attention_output, self_attn_weights
@@ -291,8 +384,8 @@ class GitLayer(GradientCheckpointingLayer):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = False,
-        pixel_values_present: Optional[bool] = False,
     ) -> tuple[torch.Tensor]:
         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
         attention_output, self_attention_weights = self.attention(
@@ -300,7 +393,7 @@ class GitLayer(GradientCheckpointingLayer):
             attention_mask,
             output_attentions=output_attentions,
             past_key_values=past_key_values,
-
+            cache_position=cache_position,
         )

         layer_output = apply_chunking_to_forward(
@@ -329,8 +422,8 @@ class GitEncoder(nn.Module):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
-        pixel_values_present: Optional[bool] = False,
         return_dict: Optional[bool] = True,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]:
         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -353,7 +446,7 @@ class GitEncoder(nn.Module):
                 attention_mask,
                 past_key_values,
                 output_attentions,
-
+                cache_position,
             )

             hidden_states = layer_outputs[0]
@@ -396,6 +489,7 @@ class GitPreTrainedModel(PreTrainedModel):
             init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
             init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         if isinstance(module, nn.Linear):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
@@ -408,6 +502,8 @@ class GitPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+        elif isinstance(module, GitEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
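`_init_weights` now also fills the persistent `position_ids` buffers (in the vision embeddings and in `GitEmbeddings`) with `0..n-1` via `init.copy_`. In plain torch, the effect of that call is simply the following; `ToyEmbeddings` is a stand-in module for illustration:

```python
import torch
from torch import nn

class ToyEmbeddings(nn.Module):
    def __init__(self, max_position_embeddings: int = 16, hidden_size: int = 8):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.register_buffer("position_ids", torch.zeros(1, max_position_embeddings, dtype=torch.long))

module = ToyEmbeddings()
with torch.no_grad():
    # equivalent of init.copy_(module.position_ids, torch.arange(...).expand((1, -1)))
    module.position_ids.copy_(torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
print(module.position_ids[0, :5])  # tensor([0, 1, 2, 3, 4])
```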
@@ -903,62 +999,6 @@ class GitModel(GitPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embeddings.word_embeddings = value

-    def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
-        # Default mask is for forward direction. Flip for backward direction.
-        mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
-        mask = mask.masked_fill(mask == 1, float("-inf"))
-        return mask
-
-    def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
-        num_tgt = tgt.shape[1]
-        num_memory = memory.shape[1]
-        device = tgt.device
-        dtype = tgt.dtype
-        top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
-        top_right = torch.full(
-            (num_memory, num_tgt + past_key_values_length),
-            float("-inf"),
-            device=tgt.device,
-            dtype=dtype,
-        )
-        bottom_left = torch.zeros(
-            (num_tgt, num_memory),
-            dtype=dtype,
-            device=tgt_mask.device,
-        )
-
-        if past_key_values_length > 0:
-            tgt_mask = torch.zeros(
-                (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
-                dtype=dtype,
-                device=tgt_mask.device,
-            )
-
-        left = torch.cat((top_left, bottom_left), dim=0)
-        right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)
-
-        full_attention_mask = torch.cat((left, right), dim=1)[None, :]
-
-        if memory_key_padding_mask is None:
-            memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
-        # if it is False, it means valid. That is, it is not a padding
-        if memory_key_padding_mask.dtype != torch.bool:
-            raise ValueError("Memory key padding mask must be a boolean tensor.")
-        zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
-        zero_negative_infinity[memory_key_padding_mask] = float("-inf")
-        full_attention_mask = full_attention_mask.expand(
-            (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
-        )
-        full_attention_mask = full_attention_mask.clone()
-        origin_left = full_attention_mask[:, :, :num_memory]
-        update = zero_negative_infinity[:, None, :]
-        full_attention_mask[:, :, :num_memory] = origin_left + update
-
-        # add axis for multi-head
-        full_attention_mask = full_attention_mask[:, None, :, :]
-
-        return full_attention_mask
-
     @auto_docstring
     def forward(
         self,
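The removed `_generate_future_mask`/`create_attention_mask` pair hand-built GIT's additive attention mask: image ("memory") tokens attend only to each other, text attends to all image tokens plus causally to earlier text. For reference, this is the block layout the removed code produced in the simplest case (no padding, no cache); the new `create_causal_mask_mapping` path expresses the same pattern through `token_type_ids` instead:

```python
import torch

num_memory, num_tgt = 2, 3  # 2 image tokens, 3 text tokens
# causal part over the text tokens, -inf strictly above the diagonal
tgt_mask = torch.triu(torch.full((num_tgt, num_tgt), float("-inf")), diagonal=1)

top = torch.cat([torch.zeros(num_memory, num_memory), torch.full((num_memory, num_tgt), float("-inf"))], dim=1)
bottom = torch.cat([torch.zeros(num_tgt, num_memory), tgt_mask], dim=1)
print(torch.cat([top, bottom], dim=0))
# rows = queries, cols = keys; 0.0 = may attend, -inf = blocked
```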
@@ -973,6 +1013,7 @@ class GitModel(GitPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
         r"""
@@ -1005,15 +1046,6 @@ class GitModel(GitPreTrainedModel):

         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
-            input_shape = input_ids.size()
-        elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        seq_length = input_shape[1]

         # past_key_values_length
         past_key_values_length = 0
@@ -1024,7 +1056,23 @@ class GitModel(GitPreTrainedModel):
|
|
|
1024
1056
|
else past_key_values.get_seq_length()
|
|
1025
1057
|
)
|
|
1026
1058
|
|
|
1027
|
-
|
|
1059
|
+
embedding_output = self.embeddings(
|
|
1060
|
+
input_ids=input_ids,
|
|
1061
|
+
position_ids=position_ids,
|
|
1062
|
+
inputs_embeds=inputs_embeds,
|
|
1063
|
+
past_key_values_length=past_key_values_length,
|
|
1064
|
+
)
|
|
1065
|
+
|
|
1066
|
+
if cache_position is None:
|
|
1067
|
+
cache_position = torch.arange(
|
|
1068
|
+
past_key_values_length,
|
|
1069
|
+
past_key_values_length + embedding_output.shape[1],
|
|
1070
|
+
device=embedding_output.device,
|
|
1071
|
+
)
|
|
1072
|
+
|
|
1073
|
+
# Always create `token_type_ids` so we can re-use Gemma3 style mask preparation fn
|
|
1074
|
+
token_type_ids = torch.zeros_like(embedding_output, dtype=torch.int)[..., 0]
|
|
1075
|
+
|
|
1028
1076
|
if pixel_values is not None:
|
|
1029
1077
|
if pixel_values.ndim == 4:
|
|
1030
1078
|
# here we assume pixel_values is of shape (batch_size, num_channels, height, width)
|
|
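The block added above derives `cache_position` lazily when the caller does not pass one: positions for the new tokens start right after whatever is already stored in the cache. A standalone sketch of that bookkeeping (the function name is illustrative, not part of the GIT model):

```python
import torch

def default_cache_position(past_length: int, embeds: torch.Tensor) -> torch.Tensor:
    # Positions of the new tokens, starting right after the cached prefix.
    return torch.arange(past_length, past_length + embeds.shape[1], device=embeds.device)

# Example: 5 tokens already cached, 3 new tokens -> positions [5, 6, 7]
print(default_cache_position(5, torch.zeros(1, 3, 8)))
```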
@@ -1050,60 +1098,54 @@ class GitModel(GitPreTrainedModel):
 
             projected_visual_features = self.visual_projection(visual_features)
 
-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-            past_key_values_length=past_key_values_length,
-        )
-
-        if projected_visual_features is None:
-            projected_visual_features = torch.zeros(
-                (embedding_output.shape[0], 0, embedding_output.shape[2]),
-                dtype=embedding_output.dtype,
-                device=embedding_output.device,
+        # Repeat visual features to match embedding batch size.
+        projected_visual_features = projected_visual_features.repeat(
+            embedding_output.size(0) // projected_visual_features.size(0), 1, 1
         )
 
-        # Repeat visual features to match embedding batch size.
-        projected_visual_features = projected_visual_features.repeat(
-            embedding_output.size(0) // projected_visual_features.size(0), 1, 1
-        )
-
-        # concatenate patch token and text token embeddings
-        hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
-
-        # By default, an additive causal mask is created
-        # for masking the future (one direction).
-        tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device)
+        # concatenate patch token and text token embeddings
+        embedding_output = torch.cat((projected_visual_features, embedding_output), dim=1)
+        image_token_type_ids = torch.ones_like(projected_visual_features, dtype=torch.int)[..., 0]
+        token_type_ids = torch.cat([image_token_type_ids, token_type_ids], dim=-1)
+        cache_position = torch.arange(embedding_output.shape[1], device=embedding_output.device, dtype=torch.int)
+        if attention_mask is not None:
+            attention_mask = torch.cat([torch.ones_like(image_token_type_ids), attention_mask], dim=-1)
+        elif past_key_values is not None and input_ids.shape[1] == 1:
+            # Expand attention mask and cache position with image tokens because GIT doesn't add image
+            # placeholder tokens when processing. Doesn't worth the refactor, low usage!
+            cache_position = torch.tensor(
+                [past_key_values_length], dtype=cache_position.dtype, device=cache_position.device
+            )
+            extended_attention_mask = torch.ones(
+                (attention_mask.shape[0], past_key_values_length - attention_mask.shape[1] + 1),
+                dtype=attention_mask.dtype,
+                device=attention_mask.device,
+            )
+            attention_mask = torch.cat([extended_attention_mask, attention_mask], dim=-1)
 
-        # Create an attention mask of shape [batch_size, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = self.create_attention_mask(
-            tgt=embedding_output,
-            memory=projected_visual_features,
-            tgt_mask=tgt_mask,
-            past_key_values_length=past_key_values_length,
+        # Images attend each other bidirectionally while text remains causal
+        causal_mask = create_causal_mask_mapping(
+            self.config,
+            embedding_output,
+            attention_mask,
+            cache_position,
+            past_key_values,
+            None,
+            token_type_ids,
+            pixel_values,
         )
 
-        if attention_mask is not None:
-            # if the user provides an attention mask, we add it to the default one
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _prepare_4d_attention_mask(
-                attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
-            ).to(embedding_output.device)
-            if past_key_values_length > 0:
-                expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
-            else:
-                combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask
+        hidden_states = embedding_output
 
         encoder_outputs = self.encoder(
             hidden_states,
-            attention_mask=combined_attention_mask,
+            attention_mask=causal_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-
+            cache_position=cache_position,
         )
         sequence_output = encoder_outputs[0]
 
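The rewritten forward above relies on `token_type_ids` (1 for image positions, 0 for text) so that the mask utility can let image tokens attend to each other bidirectionally while text keeps its causal pattern. A small standalone sketch of that masking rule, built directly with tensor ops rather than the library's mask utilities (which also fold in padding and cache offsets):

```python
import torch

def image_text_mask(token_type_ids: torch.Tensor) -> torch.Tensor:
    # token_type_ids: [batch, seq], 1 = image token, 0 = text token.
    seq = token_type_ids.shape[1]
    causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool))      # text attends to the past only
    both_image = (token_type_ids[:, :, None] == 1) & (token_type_ids[:, None, :] == 1)
    return causal[None] | both_image                                 # [batch, seq, seq], True = may attend

# 2 image tokens followed by 3 text tokens
tt = torch.tensor([[1, 1, 0, 0, 0]])
print(image_text_mask(tt).int())
```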
@@ -1157,6 +1199,7 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
         r"""
@@ -1306,6 +1349,7 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
             output_hidden_states=output_hidden_states,
             interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
+            cache_position=cache_position,
         )
 
         hidden_states = outputs[0]
@@ -1339,7 +1383,15 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
         )
 
     def prepare_inputs_for_generation(
-        self,
+        self,
+        input_ids,
+        past_key_values=None,
+        pixel_values=None,
+        attention_mask=None,
+        use_cache=None,
+        cache_position=None,
+        is_first_iteration=False,
+        **kwargs,
     ):
         # Overwritten -- `git` has special cache handling and doesn't support generating from `inputs_embeds` atm
 
@@ -1364,11 +1416,14 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
         model_inputs = {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
-            "pixel_values": kwargs.get("pixel_values"),
             "past_key_values": past_key_values,
             "use_cache": use_cache,
+            "cache_position": cache_position,
         }
 
+        if is_first_iteration or not use_cache:
+            model_inputs["pixel_values"] = pixel_values
+
         # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
         for key, value in kwargs.items():
             if key not in model_inputs:
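The reworked `prepare_inputs_for_generation` above only forwards `pixel_values` on the first step (or when no cache is used), so the image is encoded once and later decoding steps rely on the KV cache. A minimal, hypothetical sketch of that gating pattern (not the actual GIT helper):

```python
def prepare_step_inputs(input_ids, pixel_values=None, use_cache=True, is_first_iteration=False):
    # Cached decoding steps reuse the image features already folded into the KV cache,
    # so the expensive vision inputs are only passed on the first forward pass.
    inputs = {"input_ids": input_ids, "use_cache": use_cache}
    if is_first_iteration or not use_cache:
        inputs["pixel_values"] = pixel_values
    return inputs
```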
@@ -79,7 +79,7 @@ class GlmRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
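The rotary-embedding change above keeps the original inverse frequencies as a non-persistent buffer instead of a plain attribute, presumably so they follow the module across `.to()` / device moves while staying out of the checkpoint. A small illustrative sketch of the same pattern (hypothetical module, not the Glm class):

```python
import torch
from torch import nn

class RotaryFreqs(nn.Module):
    def __init__(self, dim: int = 8, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Buffers move with the module (device/dtype casts) but, with
        # persistent=False, are not written into the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m = RotaryFreqs()
print(list(m.state_dict().keys()))  # [] -> neither buffer is serialized
```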
@@ -284,7 +284,7 @@ class Glm4RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -354,7 +354,6 @@ class Glm46VImageProcessor(BaseImageProcessor):
             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                 Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                 `True`.
-                The max pixels of the image to resize the image.
             patch_size (`int`, *optional*, defaults to `self.patch_size`):
                 The spatial patch size of the vision encoder.
             temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
@@ -381,12 +380,9 @@ class Glm46VImageProcessor(BaseImageProcessor):
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
 
         """
-        # Try to use config values if set, otherwise fallback to global defaults
         size = size if size is not None else self.size
         if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
             raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
-        elif size is None:
-            size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000}
 
         do_resize = do_resize if do_resize is not None else self.do_resize
         resample = resample if resample is not None else self.resample
@@ -639,6 +639,7 @@ class Glm46VForConditionalGeneration(Glm46VPreTrainedModel, GenerationMixin):
         pixel_values_videos=None,
         image_grid_thw=None,
         video_grid_thw=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -655,13 +656,14 @@ class Glm46VForConditionalGeneration(Glm46VPreTrainedModel, GenerationMixin):
             image_grid_thw=image_grid_thw,
             video_grid_thw=video_grid_thw,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
         # GLM-4.1V position_ids are prepareed with rope_deltas in forward
         model_inputs["position_ids"] = None
 
-        if cache_position[0] != 0:
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_values_videos"] = None
 
@@ -110,6 +110,9 @@ class Glm46VPreTrainedModel(Glm4vPreTrainedModel):
     _can_record_outputs = None
     _no_split_modules = None
 
+    def _init_weights(self, module):
+        raise AttributeError("Not needed")
+
 
 class Glm46VModel(Glm4vModel):
     _no_split_modules = None
@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
+from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -38,7 +38,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_glm4_moe import Glm4MoeConfig
 
@@ -60,7 +60,7 @@ class Glm4MoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -332,6 +332,7 @@ class Glm4MoeRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
+@use_experts_implementation
 class Glm4MoeNaiveMoe(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
@@ -339,7 +340,7 @@ class Glm4MoeNaiveMoe(nn.Module):
         super().__init__()
         self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
-        self.intermediate_dim = config.intermediate_size
+        self.intermediate_dim = config.moe_intermediate_size
         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
         self.act_fn = ACT2FN[config.hidden_act]
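`Glm4MoeNaiveMoe` above stores all experts as stacked 3D tensors (`gate_up_proj` of shape `[num_experts, 2 * intermediate, hidden]`, `down_proj` of shape `[num_experts, hidden, intermediate]`), presumably so the `@use_experts_implementation` decorator can swap between a naive loop and fused/grouped kernels. A rough, self-contained sketch of a naive routed forward over such stacked weights; the gate/up split, activation, and routing-weight handling are assumptions for illustration, not the library implementation:

```python
import torch
import torch.nn.functional as F

def naive_experts_forward(x, gate_up_proj, down_proj, top_idx, top_w):
    # x: [tokens, H]; gate_up_proj: [E, 2*I, H]; down_proj: [E, H, I]
    # top_idx / top_w: [tokens, k] routing assignments and their weights.
    out = torch.zeros_like(x)
    for e in range(gate_up_proj.shape[0]):
        token_ids, slot = torch.where(top_idx == e)
        if token_ids.numel() == 0:
            continue
        gate_up = x[token_ids] @ gate_up_proj[e].T        # [n, 2*I]
        gate, up = gate_up.chunk(2, dim=-1)
        hidden = F.silu(gate) * up                        # [n, I]
        expert_out = hidden @ down_proj[e].T              # [n, H]
        out.index_add_(0, token_ids, expert_out * top_w[token_ids, slot, None])
    return out
```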
@@ -486,7 +487,9 @@ class Glm4MoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": Glm4MoeDecoderLayer,
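The hunk above ties `_can_compile_fullgraph` to `is_grouped_mm_available()`, so full-graph `torch.compile` support appears to be advertised only when the grouped matrix-multiply path is usable. A generic sketch of the same capability-gating pattern; the probe below is a stand-in, not the transformers helper:

```python
import torch

def grouped_mm_supported() -> bool:
    # Stand-in probe: the real check lives in transformers.utils and depends on
    # the installed torch version / available kernels.
    return hasattr(torch, "_grouped_mm")

class MoEModelBase:
    # Only claim full-graph compile support when a fused experts path exists;
    # a Python loop over experts would otherwise break the static graph.
    can_compile_fullgraph = grouped_mm_supported()
```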
@@ -499,6 +502,7 @@ class Glm4MoePreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, Glm4MoeTopkRouter):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            init.zeros_(module.e_score_correction_bias)
         elif isinstance(module, Glm4MoeNaiveMoe):
             init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
@@ -353,7 +353,6 @@ class Glm4vImageProcessor(BaseImageProcessor):
             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                 Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                 `True`.
-                The max pixels of the image to resize the image.
             patch_size (`int`, *optional*, defaults to `self.patch_size`):
                 The spatial patch size of the vision encoder.
             temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
@@ -380,12 +379,9 @@ class Glm4vImageProcessor(BaseImageProcessor):
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
 
         """
-        # Try to use config values if set, otherwise fallback to global defaults
         size = size if size is not None else self.size
         if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
             raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
-        elif size is None:
-            size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000}
 
         do_resize = do_resize if do_resize is not None else self.do_resize
         resample = resample if resample is not None else self.resample