transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/swinv2/modeling_swinv2.py

@@ -421,40 +421,8 @@ class Swinv2SelfAttention(nn.Module):
             nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
         )

-        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
-        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
-        relative_coords_table = (
-            torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
-            .permute(1, 2, 0)
-            .contiguous()
-            .unsqueeze(0)
-        )  # [1, 2*window_height - 1, 2*window_width - 1, 2]
-        if pretrained_window_size[0] > 0:
-            relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
-            relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
-        elif window_size > 1:
-            relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
-            relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
-        relative_coords_table *= 8  # normalize to -8, 8
-        relative_coords_table = (
-            torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
-        )
-        # set to same dtype as mlp weight
-        relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
+        relative_coords_table, relative_position_index = self.create_coords_table_and_index()
         self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
-
-        # get pair-wise relative position index for each token inside the window
-        coords_h = torch.arange(self.window_size[0])
-        coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
-        coords_flatten = torch.flatten(coords, 1)
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
-        relative_coords[:, :, 0] += self.window_size[0] - 1
-        relative_coords[:, :, 1] += self.window_size[1] - 1
-        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
-        relative_position_index = relative_coords.sum(-1)
         self.register_buffer("relative_position_index", relative_position_index, persistent=False)

         self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

@@ -530,6 +498,43 @@ class Swinv2SelfAttention(nn.Module):

         return outputs

+    def create_coords_table_and_index(self):
+        # get relative_coords_table
+        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
+        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
+        relative_coords_table = (
+            torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
+            .permute(1, 2, 0)
+            .contiguous()
+            .unsqueeze(0)
+        )  # [1, 2*window_height - 1, 2*window_width - 1, 2]
+        if self.pretrained_window_size[0] > 0:
+            relative_coords_table[:, :, :, 0] /= self.pretrained_window_size[0] - 1
+            relative_coords_table[:, :, :, 1] /= self.pretrained_window_size[1] - 1
+        elif self.window_size[0] > 1:
+            relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
+            relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
+        relative_coords_table *= 8  # normalize to -8, 8
+        relative_coords_table = (
+            torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
+        )
+        # set to same dtype as mlp weight
+        relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
+        coords_flatten = torch.flatten(coords, 1)
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+        relative_coords[:, :, 0] += self.window_size[0] - 1
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)
+
+        return relative_coords_table, relative_position_index
+

 # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->Swinv2
 class Swinv2SelfOutput(nn.Module):

@@ -904,6 +909,9 @@ class Swinv2PreTrainedModel(PreTrainedModel):
             init.zeros_(module.position_embeddings)
         elif isinstance(module, Swinv2SelfAttention):
             init.constant_(module.logit_scale, math.log(10))
+            relative_coords_table, relative_position_index = module.create_coords_table_and_index()
+            init.copy_(module.relative_coords_table, relative_coords_table)
+            init.copy_(module.relative_position_index, relative_position_index)


 @auto_docstring
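The Swinv2 hunks above move the relative-coordinate math out of `__init__` into the reusable `create_coords_table_and_index()` helper, so that `_init_weights` can re-materialize `relative_coords_table` and `relative_position_index`; both are registered with `persistent=False` and therefore never come from a checkpoint. As a minimal standalone sketch (not the transformers code; a toy 2×2 window), this is what the index buffer contains:

```python
import torch

window_size = (2, 2)  # toy value for illustration
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
coords_flatten = torch.flatten(coords, 1)  # [2, window_area]
# pair-wise (dh, dw) offsets between every pair of positions in the window
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += window_size[0] - 1  # shift offsets to start at 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1  # row-major flattening
relative_position_index = relative_coords.sum(-1)
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
```

Each entry is a flat index into the (2h−1)(2w−1) table of position biases, so pairs of tokens with equal offsets share a bias.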
transformers/models/switch_transformers/modeling_switch_transformers.py

@@ -111,7 +111,7 @@ class SwitchTransformersTop1Router(nn.Module):
         router_logits, expert_index = torch.max(router_probs, dim=-1, keepdim=True)
         expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)
         token_priority = torch.cumsum(expert_index, dim=-2)
-        # mask if the token routed to
+        # mask if the token routed to the expert will overflow
         expert_capacity_mask = token_priority <= self.expert_capacity
         expert_index = expert_index * expert_capacity_mask
         router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
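The corrected comment in this router hunk describes capacity masking: `cumsum` over the sequence gives each token a running per-expert priority, and tokens whose priority exceeds `expert_capacity` are dropped from that expert. A small sketch with toy values (two experts, capacity 1, every token routed to expert 0; this mirrors only the masking lines above, not the full router):

```python
import torch

num_experts, expert_capacity = 2, 1
routes = torch.tensor([0, 0, 0])  # three tokens, all picking expert 0
expert_index = torch.nn.functional.one_hot(routes, num_classes=num_experts)
token_priority = torch.cumsum(expert_index, dim=-2)  # running count per expert
expert_capacity_mask = token_priority <= expert_capacity
print(expert_index * expert_capacity_mask)
# tensor([[1, 0],   first token fits within capacity
#         [0, 0],   second token overflows expert 0 and is dropped
#         [0, 0]])  so is the third
```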
@@ -913,6 +913,7 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
         "encoder.embed_tokens.weight": "shared.weight",
         "decoder.embed_tokens.weight": "shared.weight",
     }
+    _input_embed_layer = "shared"

     def __init__(self, config: SwitchTransformersConfig):
         super().__init__(config)

@@ -921,20 +922,15 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = SwitchTransformersStack(encoder_config)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         self.decoder = SwitchTransformersStack(decoder_config)

         # Initialize weights and apply final processing
         self.post_init()

-    def get_input_embeddings(self):
-        return self.shared
-
     def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)

@@ -1072,12 +1068,10 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = SwitchTransformersStack(encoder_config)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = SwitchTransformersStack(decoder_config)
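The new `_input_embed_layer = "shared"` class attribute replaces the deleted `get_input_embeddings` override; the apparent intent is that the base class now resolves the input-embedding module from this attribute name instead of requiring each model to override the getter. A hedged sketch of that pattern, where the base-class body is an illustrative guess and `SwitchLikeModel` a hypothetical stand-in, not the actual `PreTrainedModel` implementation:

```python
import torch.nn as nn

class PreTrainedSketch(nn.Module):
    # name of the attribute holding the input embeddings (guessed default)
    _input_embed_layer = "embed_tokens"

    def get_input_embeddings(self) -> nn.Module:
        return getattr(self, self._input_embed_layer)

class SwitchLikeModel(PreTrainedSketch):
    _input_embed_layer = "shared"  # matches the diff above

    def __init__(self, vocab_size: int = 32, d_model: int = 8):
        super().__init__()
        self.shared = nn.Embedding(vocab_size, d_model)

model = SwitchLikeModel()
assert model.get_input_embeddings() is model.shared
```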
transformers/models/switch_transformers/modular_switch_transformers.py

@@ -170,7 +170,7 @@ class SwitchTransformersTop1Router(nn.Module):
         router_logits, expert_index = torch.max(router_probs, dim=-1, keepdim=True)
         expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)
         token_priority = torch.cumsum(expert_index, dim=-2)
-        # mask if the token routed to
+        # mask if the token routed to the expert will overflow
         expert_capacity_mask = token_priority <= self.expert_capacity
         expert_index = expert_index * expert_capacity_mask
         router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)

@@ -669,6 +669,7 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
         "encoder.embed_tokens.weight": "shared.weight",
         "decoder.embed_tokens.weight": "shared.weight",
     }
+    _input_embed_layer = "shared"

     def __init__(self, config: SwitchTransformersConfig):
         super().__init__(config)

@@ -677,20 +678,15 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = SwitchTransformersStack(encoder_config)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         self.decoder = SwitchTransformersStack(decoder_config)

         # Initialize weights and apply final processing
         self.post_init()

-    def get_input_embeddings(self):
-        return self.shared
-
     def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)

@@ -763,12 +759,10 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = SwitchTransformersStack(encoder_config)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = SwitchTransformersStack(decoder_config)
transformers/models/t5/configuration_t5.py

@@ -131,13 +131,19 @@ class T5Config(PreTrainedConfig):
         if feed_forward_proj == "gated-gelu":
             self.dense_act_fn = "gelu_new"

+        # Super weird feature of T5 because we support T5 and T51.1 from the same
+        # model code. Original T5 always scaled outputs, but the 1.1v does not.
+        # The model code was relying on saved configs where `tie_word_embeddings` is
+        # set to `False` in 1.1v and using it as indicator of whether to scale or not
+        # But in fact we tie weights always and force it to be `True`
+        self.scale_decoder_outputs = kwargs.get("tie_word_embeddings") is not False
+        kwargs["tie_word_embeddings"] = True
         super().__init__(
             pad_token_id=pad_token_id,
             eos_token_id=eos_token_id,
             is_encoder_decoder=is_encoder_decoder,
             **kwargs,
         )
-        self.tie_encoder_decoder = True  # T5 is always tied, has always been like that.


 __all__ = ["T5Config"]
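The added `scale_decoder_outputs` flag decouples output scaling from weight tying: original T5 checkpoints scale decoder outputs by `d_model**-0.5`, while v1.1 checkpoints (historically saved with `tie_word_embeddings=False`) do not, and embeddings are now tied unconditionally. A standalone restatement of the flag logic from the hunk above (`resolve_flags` is an illustrative helper, not a transformers API):

```python
def resolve_flags(**kwargs):
    # scale unless the saved config explicitly set tie_word_embeddings=False
    scale_decoder_outputs = kwargs.get("tie_word_embeddings") is not False
    kwargs["tie_word_embeddings"] = True  # weights are always tied from now on
    return scale_decoder_outputs, kwargs

assert resolve_flags()[0] is True                         # original T5, key absent
assert resolve_flags(tie_word_embeddings=True)[0] is True
scale, kwargs = resolve_flags(tie_word_embeddings=False)  # T5 v1.1 style config
assert scale is False and kwargs["tie_word_embeddings"] is True
```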
transformers/models/t5/modeling_t5.py

@@ -844,12 +844,10 @@ class T5Model(T5PreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = T5Stack(encoder_config)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = T5Stack(decoder_config)

@@ -1007,12 +1005,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = T5Stack(encoder_config)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = T5Stack(decoder_config)

@@ -1147,7 +1143,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):

         sequence_output = decoder_outputs[0]

-        if self.config.tie_word_embeddings:
+        if self.config.scale_decoder_outputs:
             sequence_output = sequence_output * (self.model_dim**-0.5)

         lm_logits = self.lm_head(sequence_output)

@@ -1487,12 +1483,10 @@ class T5ForQuestionAnswering(T5PreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = T5Stack(encoder_config)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = T5Stack(decoder_config)
transformers/models/t5gemma/modeling_t5gemma.py

@@ -108,7 +108,7 @@ class T5GemmaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
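Registering `original_inv_freq` as a non-persistent buffer, rather than keeping it as a plain tensor attribute, means the saved copy follows the module through `.to()` device and dtype moves while still staying out of checkpoints. A minimal sketch of the difference with toy values (not the T5Gemma module itself):

```python
import torch
import torch.nn as nn

class RotarySketch(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.tensor([1.0, 0.1, 0.01])  # toy frequencies
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.plain_copy = inv_freq.clone()  # plain attribute: ignored by .to()
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m = RotarySketch().to(torch.float64)
print(m.inv_freq.dtype, m.original_inv_freq.dtype)  # torch.float64 torch.float64
print(m.plain_copy.dtype)                           # torch.float32 (left behind)
```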
@@ -32,9 +32,9 @@ logger = logging.get_logger(__name__)
|
|
|
32
32
|
|
|
33
33
|
class T5Gemma2TextConfig(PreTrainedConfig):
|
|
34
34
|
r"""
|
|
35
|
-
This is the configuration class to store the configuration of a [`T5Gemma2TextModel`]. It is used to instantiate
|
|
36
|
-
model according to the specified arguments, defining the model architecture. Instantiating
|
|
37
|
-
defaults will yield a similar configuration to that of the T5Gemma2Text-7B.
|
|
35
|
+
This is the configuration class to store the configuration of a [`T5Gemma2TextModel`]. It is used to instantiate the encoder's
|
|
36
|
+
text model portion of the T5Gemma2 Model according to the specified arguments, defining the model architecture. Instantiating
|
|
37
|
+
a configuration with the defaults will yield a similar configuration to that of the T5Gemma2Text-7B.
|
|
38
38
|
e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b)
|
|
39
39
|
Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
|
|
40
40
|
documentation from [`PreTrainedConfig`] for more information.
|
|
@@ -99,19 +99,6 @@ class T5Gemma2TextConfig(PreTrainedConfig):
             Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
             a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
             with longer `max_position_embeddings`.
-        use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
-            If True, the model will attend to all text tokens instead of using a causal mask. This does not change
-            behavior for vision tokens.
-
-    ```python
-    >>> from transformers import T5Gemma2TextModel, T5Gemma2TextConfig
-    >>> # Initializing a T5Gemma2Text t5gemma2_text-7b style configuration
-    >>> configuration = T5Gemma2TextConfig()
-    >>> # Initializing a model from the t5gemma2_text-7b style configuration
-    >>> model = T5Gemma2TextModel(configuration)
-    >>> # Accessing the model configuration
-    >>> configuration = model.config
-    ```
     """

     model_type = "t5gemma2_text"
@@ -158,7 +145,6 @@ class T5Gemma2TextConfig(PreTrainedConfig):
         final_logit_softcapping: Optional[float] = None,
         attn_logit_softcapping: Optional[float] = None,
         rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
-        use_bidirectional_attention: Optional[bool] = False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -181,10 +167,6 @@ class T5Gemma2TextConfig(PreTrainedConfig):
         self.attn_logit_softcapping = attn_logit_softcapping
         self.layer_types = layer_types

-        self.use_bidirectional_attention = use_bidirectional_attention
-        if use_bidirectional_attention:
-            self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds
-
         # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
         self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)

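With `use_bidirectional_attention` gone from both text configs (the flag also halved `sliding_window` on the side), the only legacy handling left here is the private `sliding_window_pattern` read from `**kwargs`. A sketch of that backward-compatibility pattern (class and field names are illustrative, not the library's):

```python
class LegacyFriendlyConfig:
    def __init__(self, sliding_window: int = 4096, **kwargs):
        self.sliding_window = sliding_window
        # BC: older Hub configs serialized `sliding_window_pattern` as a bare
        # int; accept it silently instead of exposing a public argument.
        self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)

cfg = LegacyFriendlyConfig(sliding_window_pattern=4)  # old-style kwarg still works
assert cfg._sliding_window_pattern == 4
```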
@@ -326,9 +308,9 @@ class T5Gemma2EncoderConfig(PreTrainedConfig):

 class T5Gemma2DecoderConfig(PreTrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`T5Gemma2DecoderModel`]. It is used to instantiate
-    model according to the specified arguments, defining the model architecture. Instantiating
-    defaults will yield a similar configuration to that of the T5Gemma2Decoder-7B.
+    This is the configuration class to store the configuration of a [`T5Gemma2DecoderModel`]. It is used to instantiate the decoder
+    text model portion of the T5Gemma2 Model according to the specified arguments, defining the model architecture. Instantiating
+    a configuration with the defaults will yield a similar configuration to that of the T5Gemma2Decoder-7B.
     e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b)
     Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PreTrainedConfig`] for more information.
@@ -393,19 +375,6 @@ class T5Gemma2DecoderConfig(PreTrainedConfig):
             Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
             a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
             with longer `max_position_embeddings`.
-        use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
-            If True, the model will attend to all text tokens instead of using a causal mask. This does not change
-            behavior for vision tokens.
-
-    ```python
-    >>> from transformers import T5Gemma2DecoderModel, T5Gemma2DecoderConfig
-    >>> # Initializing a T5Gemma2Decoder t5gemma2_text-7b style configuration
-    >>> configuration = T5Gemma2DecoderConfig()
-    >>> # Initializing a model from the t5gemma2_text-7b style configuration
-    >>> model = T5Gemma2DecoderModel(configuration)
-    >>> # Accessing the model configuration
-    >>> configuration = model.config
-    ```
     """

     model_type = "t5gemma2_decoder"
@@ -452,7 +421,6 @@ class T5Gemma2DecoderConfig(PreTrainedConfig):
         final_logit_softcapping: Optional[float] = None,
         attn_logit_softcapping: Optional[float] = None,
         rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
-        use_bidirectional_attention: Optional[bool] = False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -475,10 +443,6 @@ class T5Gemma2DecoderConfig(PreTrainedConfig):
         self.attn_logit_softcapping = attn_logit_softcapping
         self.layer_types = layer_types

-        self.use_bidirectional_attention = use_bidirectional_attention
-        if use_bidirectional_attention:
-            self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds
-
         # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
         self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)

@@ -113,7 +113,7 @@ class T5Gemma2RotaryEmbedding(nn.Module):
             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
             curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
             self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-
+            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
             setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)

     @staticmethod
@@ -266,7 +266,7 @@ class T5Gemma2SelfAttention(nn.Module):
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = config.query_pre_attn_scalar**-0.5
         self.attention_dropout = self.config.attention_dropout
-        self.is_causal =
+        self.is_causal = False # Only used by the encoder

         self.q_proj = nn.Linear(
             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
@@ -348,7 +348,7 @@ class T5Gemma2MergedAttention(nn.Module):
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = config.query_pre_attn_scalar**-0.5
         self.attention_dropout = self.config.attention_dropout
-        self.is_causal =
+        self.is_causal = False # Fused causal and encoder mask

         self.q_proj = nn.Linear(
             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
@@ -446,7 +446,6 @@ class T5Gemma2MergedAttention(nn.Module):
             merged_attention_mask,
             dropout=self.attention_dropout if self.training else 0.0,
             scaling=self.scaling,
-            is_causal=False,
             **kwargs,
         )

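Since `self.is_causal` is now hard-set to `False` on the module, the explicit `is_causal=False` override at the call site became redundant and was dropped: causality is carried entirely by the fused mask. A small SDPA sketch of mask-driven (rather than flag-driven) causality:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq, head_dim)

# Stand-in for the fused encoder/decoder mask: the causal structure lives in
# the mask itself, so the is_causal flag stays False.
merged_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
out = F.scaled_dot_product_attention(q, k, v, attn_mask=merged_mask, is_causal=False)
```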
@@ -649,6 +648,7 @@ class T5Gemma2TextScaledWordEmbedding(nn.Embedding):
         eoi_token_index: int = 256_000,
     ):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.scalar_embed_scale = embed_scale
         self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
         self.eoi_token_index = eoi_token_index
         self.eoi_embedding = nn.Parameter(torch.zeros(self.embedding_dim))
@@ -700,6 +700,7 @@ class T5Gemma2PreTrainedModel(PreTrainedModel):
             init.zeros_(module.mm_input_projection_weight)
         elif isinstance(module, T5Gemma2TextScaledWordEmbedding):
             init.zeros_(module.eoi_embedding)
+            init.constant_(module.embed_scale, module.scalar_embed_scale)
         elif isinstance(module, T5Gemma2ClassificationHead):
             scale = module.out_proj.weight.shape[0] ** -0.5
             init.normal_(module.out_proj.weight, mean=0.0, std=self.config.initializer_range * scale)
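Keeping the plain-Python `scalar_embed_scale` beside the tensor buffer gives `_init_weights` a source of truth to restore `embed_scale` from, e.g. after meta-device initialization where non-persistent buffers come back uninitialized. A sketch of the pairing, assuming only the names shown in the diff:

```python
import torch
from torch import nn

class ScaledEmbedding(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.scalar_embed_scale = embed_scale  # plain float: always available
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

emb = ScaledEmbedding(10, 4, embed_scale=2.0)
# The re-init path can rebuild the buffer from the stored scalar:
torch.nn.init.constant_(emb.embed_scale, emb.scalar_embed_scale)
```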
@@ -708,6 +709,14 @@ class T5Gemma2PreTrainedModel(PreTrainedModel):
         # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
         elif "RMSNorm" in module.__class__.__name__:
             init.zeros_(module.weight)
+        elif isinstance(module, T5Gemma2RotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)

     def prepare_decoder_input_ids_from_labels(self, input_ids):
         """