transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0

--- transformers/models/git/modeling_git.py
+++ transformers/models/git/modeling_git.py
@@ -26,8 +26,9 @@ from torch import nn
 from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
-from ...
+from ...masking_utils import create_masks_for_generate
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -69,6 +70,104 @@ class GitVisionModelOutput(ModelOutput):
     attentions: Optional[tuple[torch.FloatTensor, ...]] = None


+# Copied from transformers.models.gemma3.modeling_gemma3.token_type_ids_mask_function
+def token_type_ids_mask_function(
+    token_type_ids: Optional[torch.Tensor],
+    image_group_ids: Optional[torch.Tensor],
+) -> Optional[Callable]:
+    """
+    This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
+    not start and end indices.
+    """
+    # Do not return an additional mask in this case
+    if token_type_ids is None:
+        return None
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        # If it's 1 for both query and key/value, we are in an image block
+        # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
+        # Since vmap doesn't support `if statement` we workaround it with `torch.where`
+        safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
+        safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
+
+        token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
+        token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)
+
+        token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
+        token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
+
+        image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
+        image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)
+
+        image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
+        image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
+
+        is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
+        same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
+
+        # This is bidirectional attention whenever we are dealing with image tokens
+        return is_image_block & same_image_block
+
+    return inner_mask
+
+
+# Copied from transformers.models.gemma3.modeling_gemma3.create_causal_mask_mapping
+def create_causal_mask_mapping(
+    config: PreTrainedConfig,
+    input_embeds: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    cache_position: torch.Tensor,
+    past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
+    token_type_ids: Optional[torch.Tensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    is_training: bool = False,
+    is_first_iteration: Optional[bool] = None,
+    **kwargs,
+) -> dict:
+    """
+    Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
+    for all kinds of forward passes. Gemma3 uses a bidirectional mask for images.
+
+    Uses `pixel_values` as an optional input to disambiguate edge cases.
+    """
+    if is_training and token_type_ids is None:
+        raise ValueError("`token_type_ids` is required as a model input when training")
+
+    mask_kwargs = {
+        "config": config.get_text_config(),
+        "input_embeds": input_embeds,
+        "attention_mask": attention_mask,
+        "cache_position": cache_position,
+        "past_key_values": past_key_values,
+        "position_ids": position_ids,
+    }
+    # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
+    # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
+    # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
+    is_first_iteration = (
+        is_first_iteration
+        if is_first_iteration is not None
+        else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
+    )
+    if token_type_ids is not None and is_first_iteration:
+        # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
+        # undo the causal masking)
+
+        # First find where a new image block starts: 1 if image and previous not image
+        # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
+        is_image = (token_type_ids == 1).to(cache_position.device)
+        is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+        new_image_start = is_image & ~is_previous_image
+        image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+        image_group_ids = torch.where(is_image, image_group_ids, -1)
+        mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+            token_type_ids.to(cache_position.device), image_group_ids
+        )
+
+    return create_masks_for_generate(**mask_kwargs)
+
+
 class GitEmbeddings(nn.Module):
     """Construct the embeddings from word and position embeddings."""

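As a quick illustration of what these copied Gemma3 helpers compute, the sketch below (not part of the diff, plain PyTorch only, with a made-up toy `token_type_ids`) runs the same image-grouping arithmetic that appears in the added `create_causal_mask_mapping`: each contiguous run of image tokens (value 1) gets its own group id, text tokens get -1, and the resulting groups are what `token_type_ids_mask_function` uses to open up bidirectional attention only inside a single image block.

```python
import torch
from torch import nn

# Toy sequence: text, image block A (2 tokens), text, image block B (3 tokens), text
token_type_ids = torch.tensor([[0, 1, 1, 0, 1, 1, 1, 0]])

is_image = token_type_ids == 1
# Shift right by one position to look at the previous token's type
is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
# A new image block starts where a token is an image token but the previous one is not
new_image_start = is_image & ~is_previous_image
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, -1)

print(image_group_ids)  # tensor([[-1, 0, 0, -1, 1, 1, 1, -1]])
# The or-mask is True only when both q and kv are image tokens with the same group id,
# so tokens of block A attend bidirectionally to each other but not to block B.
```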
@@ -148,17 +247,15 @@ class GitSelfAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Cache] = None,
-
-        pixel_values_present: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor]:
-        batch_size
+        batch_size = hidden_states.shape[0]
         query_layer = (
             self.query(hidden_states)
             .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
             .transpose(1, 2)
         )

-        cutoff = self.image_patch_tokens if pixel_values_present else 0
         key_layer = (
             self.key(hidden_states)
             .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
@@ -170,12 +267,9 @@ class GitSelfAttention(nn.Module):
|
|
|
170
267
|
.transpose(1, 2)
|
|
171
268
|
)
|
|
172
269
|
if past_key_values is not None:
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
key_layer[:, :, cutoff:, :], value_layer[:, :, cutoff:, :], self.layer_idx
|
|
270
|
+
key_layer, value_layer = past_key_values.update(
|
|
271
|
+
key_layer, value_layer, self.layer_idx, cache_kwargs={"cache_position": cache_position}
|
|
176
272
|
)
|
|
177
|
-
key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2)
|
|
178
|
-
value_layer = torch.cat([value_layer[:, :, :cutoff, :], value_layer_past], dim=2)
|
|
179
273
|
|
|
180
274
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
|
181
275
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
|
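For context on the new `cache_kwargs={"cache_position": cache_position}` argument, the sketch below is a stand-in, not the transformers `Cache` API: it shows how an update keyed by explicit positions can write new key/value states into a preallocated buffer instead of relying on concatenation and offset bookkeeping.

```python
# Minimal sketch (assumed shapes, not the transformers Cache class) of a
# position-indexed cache update like the one driven by `cache_position`.
import torch

batch, heads, max_len, head_dim = 1, 2, 16, 8
k_cache = torch.zeros(batch, heads, max_len, head_dim)
v_cache = torch.zeros(batch, heads, max_len, head_dim)

def update(k_cache, v_cache, key_states, value_states, cache_position):
    # cache_position holds the absolute slot index of every new token
    k_cache.index_copy_(2, cache_position, key_states)
    v_cache.index_copy_(2, cache_position, value_states)
    return k_cache, v_cache

# prefill: 4 tokens at positions 0..3, then one decode step at position 4
update(k_cache, v_cache, torch.randn(batch, heads, 4, head_dim),
       torch.randn(batch, heads, 4, head_dim), torch.arange(4))
update(k_cache, v_cache, torch.randn(batch, heads, 1, head_dim),
       torch.randn(batch, heads, 1, head_dim), torch.tensor([4]))
```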
@@ -232,15 +326,14 @@ class GitAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = False,
-        pixel_values_present: Optional[bool] = False,
     ) -> tuple[torch.Tensor]:
         attn_output, self_attn_weights = self.self(
             hidden_states,
             attention_mask,
             past_key_values,
-
-            pixel_values_present,
+            cache_position=cache_position,
         )
         attention_output = self.output(attn_output, hidden_states)
         return attention_output, self_attn_weights

@@ -291,8 +384,8 @@ class GitLayer(GradientCheckpointingLayer):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = False,
-        pixel_values_present: Optional[bool] = False,
     ) -> tuple[torch.Tensor]:
         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
         attention_output, self_attention_weights = self.attention(

@@ -300,7 +393,7 @@ class GitLayer(GradientCheckpointingLayer):
             attention_mask,
             output_attentions=output_attentions,
             past_key_values=past_key_values,
-
+            cache_position=cache_position,
         )

         layer_output = apply_chunking_to_forward(

@@ -329,8 +422,8 @@ class GitEncoder(nn.Module):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
-        pixel_values_present: Optional[bool] = False,
         return_dict: Optional[bool] = True,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]:
         if self.gradient_checkpointing and self.training:
             if use_cache:

@@ -353,7 +446,7 @@ class GitEncoder(nn.Module):
                 attention_mask,
                 past_key_values,
                 output_attentions,
-
+                cache_position,
             )

             hidden_states = layer_outputs[0]

@@ -396,6 +489,7 @@ class GitPreTrainedModel(PreTrainedModel):
             init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
             init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         if isinstance(module, nn.Linear):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:

@@ -408,6 +502,8 @@ class GitPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+        elif isinstance(module, GitEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
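The `_init_weights` additions re-seed the registered `position_ids` buffer with `torch.arange`. A minimal sketch of the same pattern on a toy module (hypothetical names, not the GIT classes):

```python
# Sketch of re-initializing a registered `position_ids` buffer inside an
# init-weights hook; ToyEmbeddings is illustrative only.
import torch
from torch import nn

class ToyEmbeddings(nn.Module):
    def __init__(self, max_position_embeddings: int = 512, hidden_size: int = 32):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        # non-persistent buffer, same (1, max_positions) shape convention as the modeling code
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )

def init_position_ids(module: nn.Module) -> None:
    if isinstance(module, ToyEmbeddings):
        with torch.no_grad():
            module.position_ids.copy_(torch.arange(module.position_ids.shape[-1]).expand((1, -1)))

model = ToyEmbeddings()
model.apply(init_position_ids)  # restores 0..N-1 even after a meta/empty init
```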
@@ -827,6 +923,7 @@ class GitVisionModel(GitPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         Examples:

@@ -902,62 +999,6 @@ class GitModel(GitPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embeddings.word_embeddings = value

-    def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
-        # Default mask is for forward direction. Flip for backward direction.
-        mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
-        mask = mask.masked_fill(mask == 1, float("-inf"))
-        return mask
-
-    def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
-        num_tgt = tgt.shape[1]
-        num_memory = memory.shape[1]
-        device = tgt.device
-        dtype = tgt.dtype
-        top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
-        top_right = torch.full(
-            (num_memory, num_tgt + past_key_values_length),
-            float("-inf"),
-            device=tgt.device,
-            dtype=dtype,
-        )
-        bottom_left = torch.zeros(
-            (num_tgt, num_memory),
-            dtype=dtype,
-            device=tgt_mask.device,
-        )
-
-        if past_key_values_length > 0:
-            tgt_mask = torch.zeros(
-                (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
-                dtype=dtype,
-                device=tgt_mask.device,
-            )
-
-        left = torch.cat((top_left, bottom_left), dim=0)
-        right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)
-
-        full_attention_mask = torch.cat((left, right), dim=1)[None, :]
-
-        if memory_key_padding_mask is None:
-            memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
-        # if it is False, it means valid. That is, it is not a padding
-        if memory_key_padding_mask.dtype != torch.bool:
-            raise ValueError("Memory key padding mask must be a boolean tensor.")
-        zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
-        zero_negative_infinity[memory_key_padding_mask] = float("-inf")
-        full_attention_mask = full_attention_mask.expand(
-            (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
-        )
-        full_attention_mask = full_attention_mask.clone()
-        origin_left = full_attention_mask[:, :, :num_memory]
-        update = zero_negative_infinity[:, None, :]
-        full_attention_mask[:, :, :num_memory] = origin_left + update
-
-        # add axis for multi-head
-        full_attention_mask = full_attention_mask[:, None, :, :]
-
-        return full_attention_mask
-
     @auto_docstring
     def forward(
         self,

@@ -972,6 +1013,8 @@ class GitModel(GitPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
         r"""
         Examples:

@@ -1003,15 +1046,6 @@ class GitModel(GitPreTrainedModel):

         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
-            input_shape = input_ids.size()
-        elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        seq_length = input_shape[1]

         # past_key_values_length
         past_key_values_length = 0

@@ -1022,7 +1056,23 @@ class GitModel(GitPreTrainedModel):
             else past_key_values.get_seq_length()
         )

-
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_key_values_length,
+                past_key_values_length + embedding_output.shape[1],
+                device=embedding_output.device,
+            )
+
+        # Always create `token_type_ids` so we can re-use Gemma3 style mask preparation fn
+        token_type_ids = torch.zeros_like(embedding_output, dtype=torch.int)[..., 0]
+
         if pixel_values is not None:
             if pixel_values.ndim == 4:
                 # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
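When the caller does not supply `cache_position`, the hunk above defaults it to the absolute positions of the tokens being fed, offset by how many tokens the cache already holds. A small sketch of that computation, assuming only the tensor shapes involved:

```python
# Sketch of the default `cache_position` computed above (hypothetical helper name).
import torch

def default_cache_position(past_key_values_length: int, embedding_output: torch.Tensor) -> torch.Tensor:
    # positions of the tokens currently being fed, offset by what is already cached
    return torch.arange(
        past_key_values_length,
        past_key_values_length + embedding_output.shape[1],
        device=embedding_output.device,
    )

prefill = torch.randn(2, 5, 16)                # (batch, seq_len, hidden)
decode_step = torch.randn(2, 1, 16)            # single new token
print(default_cache_position(0, prefill))      # tensor([0, 1, 2, 3, 4])
print(default_cache_position(5, decode_step))  # tensor([5])
```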
@@ -1048,60 +1098,54 @@ class GitModel(GitPreTrainedModel):

             projected_visual_features = self.visual_projection(visual_features)

-
-
-
-                inputs_embeds=inputs_embeds,
-                past_key_values_length=past_key_values_length,
-            )
-
-        if projected_visual_features is None:
-            projected_visual_features = torch.zeros(
-                (embedding_output.shape[0], 0, embedding_output.shape[2]),
-                dtype=embedding_output.dtype,
-                device=embedding_output.device,
+            # Repeat visual features to match embedding batch size.
+            projected_visual_features = projected_visual_features.repeat(
+                embedding_output.size(0) // projected_visual_features.size(0), 1, 1
             )

-
-
-
-
-
-
-
-
-
-
-
+            # concatenate patch token and text token embeddings
+            embedding_output = torch.cat((projected_visual_features, embedding_output), dim=1)
+            image_token_type_ids = torch.ones_like(projected_visual_features, dtype=torch.int)[..., 0]
+            token_type_ids = torch.cat([image_token_type_ids, token_type_ids], dim=-1)
+            cache_position = torch.arange(embedding_output.shape[1], device=embedding_output.device, dtype=torch.int)
+            if attention_mask is not None:
+                attention_mask = torch.cat([torch.ones_like(image_token_type_ids), attention_mask], dim=-1)
+        elif past_key_values is not None and input_ids.shape[1] == 1:
+            # Expand attention mask and cache position with image tokens because GIT doesn't add image
+            # placeholder tokens when processing. Doesn't worth the refactor, low usage!
+            cache_position = torch.tensor(
+                [past_key_values_length], dtype=cache_position.dtype, device=cache_position.device
+            )
+            extended_attention_mask = torch.ones(
+                (attention_mask.shape[0], past_key_values_length - attention_mask.shape[1] + 1),
+                dtype=attention_mask.dtype,
+                device=attention_mask.device,
+            )
+            attention_mask = torch.cat([extended_attention_mask, attention_mask], dim=-1)

-        #
-
-
-
-
-
+        # Images attend each other bidirectionally while text remains causal
+        causal_mask = create_causal_mask_mapping(
+            self.config,
+            embedding_output,
+            attention_mask,
+            cache_position,
+            past_key_values,
+            None,
+            token_type_ids,
+            pixel_values,
         )

-
-            # if the user provides an attention mask, we add it to the default one
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _prepare_4d_attention_mask(
-                attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
-            ).to(embedding_output.device)
-            if past_key_values_length > 0:
-                expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
-            else:
-                combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask
+        hidden_states = embedding_output

         encoder_outputs = self.encoder(
             hidden_states,
-            attention_mask=
+            attention_mask=causal_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-
+            cache_position=cache_position,
         )
         sequence_output = encoder_outputs[0]

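The rewritten multimodal path prepends the projected patch embeddings to the text embeddings, marks them with ones in `token_type_ids`, and pads the 2D attention mask so the image positions are never masked out. A self-contained sketch of that layout (toy shapes, plain PyTorch, not the GitModel code):

```python
# Sketch of the image-prefix layout built in the hunk above.
import torch

batch, num_patches, seq_len, hidden = 2, 4, 6, 16
projected_visual_features = torch.randn(batch, num_patches, hidden)
text_embeds = torch.randn(batch, seq_len, hidden)
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)

# image patches come first, then the text tokens
embedding_output = torch.cat((projected_visual_features, text_embeds), dim=1)
token_type_ids = torch.cat(
    [
        torch.ones(batch, num_patches, dtype=torch.int),   # image positions
        torch.zeros(batch, seq_len, dtype=torch.int),      # text positions
    ],
    dim=-1,
)
attention_mask = torch.cat([torch.ones(batch, num_patches, dtype=torch.long), attention_mask], dim=-1)

print(embedding_output.shape)  # torch.Size([2, 10, 16])
print(token_type_ids[0])       # tensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0], dtype=torch.int32)
```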
@@ -1155,6 +1199,7 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
         r"""

@@ -1304,6 +1349,7 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
             output_hidden_states=output_hidden_states,
             interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
+            cache_position=cache_position,
         )

         hidden_states = outputs[0]

@@ -1337,7 +1383,15 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
         )

     def prepare_inputs_for_generation(
-        self,
+        self,
+        input_ids,
+        past_key_values=None,
+        pixel_values=None,
+        attention_mask=None,
+        use_cache=None,
+        cache_position=None,
+        is_first_iteration=False,
+        **kwargs,
     ):
         # Overwritten -- `git` has special cache handling and doesn't support generating from `inputs_embeds` atm

@@ -1362,11 +1416,14 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
         model_inputs = {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
-            "pixel_values": kwargs.get("pixel_values"),
             "past_key_values": past_key_values,
             "use_cache": use_cache,
+            "cache_position": cache_position,
         }

+        if is_first_iteration or not use_cache:
+            model_inputs["pixel_values"] = pixel_values
+
         # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
         for key, value in kwargs.items():
             if key not in model_inputs:
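`prepare_inputs_for_generation` now forwards `pixel_values` only on the first generation step (or when caching is off); on later decode steps the image content already lives in the KV cache. A hypothetical helper illustrating the same gating, not the GitForCausalLM method itself:

```python
# Sketch of the first-iteration gating of image inputs (assumed helper name).
import torch
from typing import Optional

def build_model_inputs(
    input_ids: torch.Tensor,
    pixel_values: Optional[torch.Tensor],
    use_cache: bool,
    is_first_iteration: bool,
) -> dict:
    model_inputs = {"input_ids": input_ids, "use_cache": use_cache}
    # only ship pixel_values when they still need to be encoded
    if is_first_iteration or not use_cache:
        model_inputs["pixel_values"] = pixel_values
    return model_inputs

prompt = torch.tensor([[1, 2, 3]])
images = torch.randn(1, 3, 224, 224)
print("pixel_values" in build_model_inputs(prompt, images, True, True))   # True  (prefill)
print("pixel_values" in build_model_inputs(prompt, images, True, False))  # False (decode steps)
```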
@@ -28,7 +28,7 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import (
     GenericForSequenceClassification,

@@ -40,7 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_glm import GlmConfig


@@ -79,7 +79,7 @@ class GlmRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
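`original_inv_freq` becomes a registered, non-persistent buffer built from a `clone()`, so it follows the module across device/dtype moves and stays untouched if `inv_freq` is later rescaled in place. A toy module (not the GLM class) showing why the clone matters:

```python
# Sketch of keeping a pristine copy of the rotary inverse frequencies as a buffer.
import torch
from torch import nn

class ToyRotary(nn.Module):
    def __init__(self, dim: int = 8, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # non-persistent: moved by .to(...) but excluded from the state dict
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

rope = ToyRotary()
rope.inv_freq.mul_(0.5)  # simulate a dynamic-scaling update of the live frequencies
print(torch.equal(rope.inv_freq, rope.original_inv_freq))  # False: the original is preserved
```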
@@ -120,7 +120,7 @@ class GlmRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
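The rotary forward computes its tables with autocast disabled so the cos/sin values stay in float32 even when the surrounding forward runs under bf16/fp16 mixed precision. The sketch below uses stock `torch.autocast` in place of the `maybe_autocast` wrapper introduced above:

```python
# Sketch of float32 RoPE table computation under disabled autocast (assumed helper name).
import torch

def rope_cos_sin(inv_freq: torch.Tensor, position_ids: torch.Tensor, device_type: str = "cpu"):
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    with torch.autocast(device_type=device_type, enabled=False):  # force float32
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))
cos, sin = rope_cos_sin(inv_freq, torch.arange(4)[None, :])
print(cos.shape, cos.dtype)  # torch.Size([1, 4, 8]) torch.float32
```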
@@ -216,6 +216,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class GlmAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""


@@ -239,7 +240,6 @@ class GlmAttention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,

@@ -28,7 +28,7 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (

@@ -41,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_glm4 import Glm4Config


@@ -198,6 +198,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class Glm4Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""


@@ -221,7 +222,6 @@ class Glm4Attention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,

@@ -284,7 +284,7 @@ class Glm4RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(

@@ -325,7 +325,7 @@ class Glm4RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -354,7 +354,6 @@ class Glm46VImageProcessor(BaseImageProcessor):
             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                 Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                 `True`.
-                The max pixels of the image to resize the image.
             patch_size (`int`, *optional*, defaults to `self.patch_size`):
                 The spatial patch size of the vision encoder.
             temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):

@@ -381,12 +380,9 @@ class Glm46VImageProcessor(BaseImageProcessor):
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        """
-        # Try to use config values if set, otherwise fallback to global defaults
         size = size if size is not None else self.size
         if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
             raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
-        elif size is None:
-            size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000}

         do_resize = do_resize if do_resize is not None else self.do_resize
         resample = resample if resample is not None else self.resample

@@ -639,6 +639,7 @@ class Glm46VForConditionalGeneration(Glm46VPreTrainedModel, GenerationMixin):
         pixel_values_videos=None,
         image_grid_thw=None,
         video_grid_thw=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

@@ -655,13 +656,14 @@ class Glm46VForConditionalGeneration(Glm46VPreTrainedModel, GenerationMixin):
             image_grid_thw=image_grid_thw,
             video_grid_thw=video_grid_thw,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

         # GLM-4.1V position_ids are prepareed with rope_deltas in forward
         model_inputs["position_ids"] = None

-        if
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_values_videos"] = None


@@ -110,6 +110,9 @@ class Glm46VPreTrainedModel(Glm4vPreTrainedModel):
     _can_record_outputs = None
     _no_split_modules = None

+    def _init_weights(self, module):
+        raise AttributeError("Not needed")
+

 class Glm46VModel(Glm4vModel):
     _no_split_modules = None