transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/switch_transformers/modeling_switch_transformers.py

@@ -111,7 +111,7 @@ class SwitchTransformersTop1Router(nn.Module):
         router_logits, expert_index = torch.max(router_probs, dim=-1, keepdim=True)
         expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)
         token_priority = torch.cumsum(expert_index, dim=-2)
-        # mask if the token routed to
+        # mask if the token routed to the expert will overflow
         expert_capacity_mask = token_priority <= self.expert_capacity
         expert_index = expert_index * expert_capacity_mask
         router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
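The corrected comment summarizes top-1 routing with a capacity limit: the cumulative sum over the one-hot expert assignments gives each token its arrival position in its expert's queue, and any token whose position exceeds `expert_capacity` is dropped. A minimal, self-contained sketch of that masking on toy tensors (not the library class itself; it uses argmax on a 2-D assignment rather than the hunk's keepdim variant):

import torch

num_experts, expert_capacity = 2, 2
# One toy batch: 5 tokens, 2 experts; the first four tokens all prefer expert 0.
router_probs = torch.tensor([[[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.2, 0.8]]])

expert_index = torch.argmax(router_probs, dim=-1)                      # (1, 5)
expert_index = torch.nn.functional.one_hot(expert_index, num_experts)  # (1, 5, 2)
# Running count along the sequence: the i-th token sent to an expert gets
# priority i within that expert's queue.
token_priority = torch.cumsum(expert_index, dim=-2)
# Mask if the token routed to the expert will overflow its capacity.
expert_index = expert_index * (token_priority <= expert_capacity)
print(expert_index.squeeze(0))
# -> [[1, 0], [1, 0], [0, 0], [0, 0], [0, 1]]: tokens 2 and 3 overflow expert 0.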
@@ -913,6 +913,7 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
         "encoder.embed_tokens.weight": "shared.weight",
         "decoder.embed_tokens.weight": "shared.weight",
     }
+    _input_embed_layer = "shared"

     def __init__(self, config: SwitchTransformersConfig):
         super().__init__(config)
@@ -921,20 +922,15 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = SwitchTransformersStack(encoder_config)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         self.decoder = SwitchTransformersStack(decoder_config)

         # Initialize weights and apply final processing
         self.post_init()

-    def get_input_embeddings(self):
-        return self.shared
-
     def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
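Both hunks in this class follow one pattern: the hand-written `get_input_embeddings` override disappears while a `_input_embed_layer = "shared"` class attribute appears, suggesting the base class now resolves the input-embedding module from that attribute name. A rough sketch of that mechanism, assuming a base-class fallback along these lines (hypothetical; the real `PreTrainedModel` logic is not shown in this diff):

import torch.nn as nn

class PreTrainedModelSketch(nn.Module):
    _input_embed_layer = "embed_tokens"  # default attribute name

    def get_input_embeddings(self) -> nn.Module:
        # Look the module up by name instead of requiring every model
        # to override this method.
        return getattr(self, self._input_embed_layer)

class SharedEmbeddingModel(PreTrainedModelSketch):
    _input_embed_layer = "shared"  # as in SwitchTransformersModel above

    def __init__(self):
        super().__init__()
        self.shared = nn.Embedding(32128, 512)

model = SharedEmbeddingModel()
assert model.get_input_embeddings() is model.shared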
@@ -1072,12 +1068,10 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = SwitchTransformersStack(encoder_config)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = SwitchTransformersStack(decoder_config)

transformers/models/switch_transformers/modular_switch_transformers.py

@@ -170,7 +170,7 @@ class SwitchTransformersTop1Router(nn.Module):
         router_logits, expert_index = torch.max(router_probs, dim=-1, keepdim=True)
         expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)
         token_priority = torch.cumsum(expert_index, dim=-2)
-        # mask if the token routed to
+        # mask if the token routed to the expert will overflow
         expert_capacity_mask = token_priority <= self.expert_capacity
         expert_index = expert_index * expert_capacity_mask
         router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
@@ -669,6 +669,7 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
         "encoder.embed_tokens.weight": "shared.weight",
         "decoder.embed_tokens.weight": "shared.weight",
     }
+    _input_embed_layer = "shared"

     def __init__(self, config: SwitchTransformersConfig):
         super().__init__(config)
@@ -677,20 +678,15 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = SwitchTransformersStack(encoder_config)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         self.decoder = SwitchTransformersStack(decoder_config)

         # Initialize weights and apply final processing
         self.post_init()

-    def get_input_embeddings(self):
-        return self.shared
-
     def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
@@ -763,12 +759,10 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = SwitchTransformersStack(encoder_config)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = SwitchTransformersStack(decoder_config)

transformers/models/t5/configuration_t5.py

@@ -131,13 +131,19 @@ class T5Config(PreTrainedConfig):
         if feed_forward_proj == "gated-gelu":
             self.dense_act_fn = "gelu_new"

+        # Super weird feature of T5 because we support T5 and T51.1 from the same
+        # model code. Original T5 always scaled outputs, but the 1.1v does not.
+        # The model code was relying on saved configs where `tie_word_embeddings` is
+        # set to `False` in 1.1v and using it as indicator of whether to scale or not
+        # But in fact we tie weights always and force it to be `True`
+        self.scale_decoder_outputs = kwargs.get("tie_word_embeddings") is not False
+        kwargs["tie_word_embeddings"] = True
         super().__init__(
             pad_token_id=pad_token_id,
             eos_token_id=eos_token_id,
             is_encoder_decoder=is_encoder_decoder,
             **kwargs,
         )
-        self.tie_encoder_decoder = True  # T5 is always tied, has always been like that.


 __all__ = ["T5Config"]
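The expression worth pausing on is `kwargs.get("tie_word_embeddings") is not False`: it evaluates to True both when the key is absent (legacy T5 configs) and when it is explicitly True, so only a saved `tie_word_embeddings: false` (the T5 1.1 convention the comment describes) disables output scaling, while actual weight tying is forced on in every case. A quick check of that derivation in plain Python:

def derive(saved_config: dict) -> tuple[bool, bool]:
    kwargs = dict(saved_config)
    scale_decoder_outputs = kwargs.get("tie_word_embeddings") is not False
    kwargs["tie_word_embeddings"] = True  # tying is always forced on
    return scale_decoder_outputs, kwargs["tie_word_embeddings"]

assert derive({}) == (True, True)                               # key absent: original T5, scale
assert derive({"tie_word_embeddings": True}) == (True, True)    # explicit True: scale
assert derive({"tie_word_embeddings": False}) == (False, True)  # T5 1.1: no scaling, still tied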
@@ -673,6 +673,7 @@ class T5Stack(T5PreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         cache_position=None,
+        **kwargs,
     ):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
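`**kwargs` is threaded through `T5Stack.forward` (and, in the hunks below, through every task head's forward) so attention-level options can flow down without each signature enumerating them. A sketch of the pattern; `AttnKwargs` is an illustrative stand-in for transformers' `TransformersKwargs` TypedDict:

```python
# Sketch of the **kwargs pass-through pattern; `AttnKwargs` is illustrative,
# standing in for transformers' TransformersKwargs TypedDict.
from typing import TypedDict
from typing_extensions import Unpack

class AttnKwargs(TypedDict, total=False):
    output_attentions: bool
    sliding_window: int

def attention(x, **kwargs: Unpack[AttnKwargs]):
    return x, kwargs

def stack_forward(x, cache_position=None, **kwargs: Unpack[AttnKwargs]):
    # The stack does not inspect kwargs; it just forwards them to attention.
    return attention(x, **kwargs)

print(stack_forward([1.0], sliding_window=128))
```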
@@ -843,12 +844,10 @@ class T5Model(T5PreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = T5Stack(encoder_config)
 
         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = T5Stack(decoder_config)
 
@@ -879,6 +878,7 @@ class T5Model(T5PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1005,12 +1005,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = T5Stack(encoder_config)
 
         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = T5Stack(decoder_config)
 
@@ -1044,6 +1042,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1144,7 +1143,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
 
         sequence_output = decoder_outputs[0]
 
-        if self.config.tie_word_embeddings:
+        if self.config.scale_decoder_outputs:
             sequence_output = sequence_output * (self.model_dim**-0.5)
 
         lm_logits = self.lm_head(sequence_output)
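With tied input/output embeddings, original T5 rescales decoder states by `d_model**-0.5` before the LM head so logits keep a comparable magnitude; the hunk above merely swaps the guard from the old `tie_word_embeddings` proxy to the explicit `scale_decoder_outputs` flag. A numeric illustration, with arbitrary shapes:

```python
# Numeric illustration of the rescaling guarded by `scale_decoder_outputs`.
import torch

model_dim = 512
sequence_output = torch.randn(2, 7, model_dim)
lm_head = torch.nn.Linear(model_dim, 32128, bias=False)  # tied to the embedding matrix in the real model

scaled = sequence_output * (model_dim**-0.5)
print(lm_head(scaled).std() / lm_head(sequence_output).std())  # ~ 512 ** -0.5
```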
@@ -1209,6 +1208,7 @@ class T5EncoderModel(T5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1279,6 +1279,7 @@ class T5ForSequenceClassification(T5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1417,6 +1418,7 @@ class T5ForTokenClassification(T5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1481,12 +1483,10 @@ class T5ForQuestionAnswering(T5PreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = T5Stack(encoder_config)
 
         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = T5Stack(decoder_config)
 
@@ -1520,6 +1520,7 @@ class T5ForQuestionAnswering(T5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -15,6 +15,7 @@
 """Tokenization class for model T5."""
 
 import re
+from typing import Optional, Union
 
 from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
 from tokenizers.models import Unigram
@@ -61,26 +62,24 @@ class T5Tokenizer(TokenizersBackend):
             calling the get_sentinel_tokens method, and token ids can be obtained by calling the get_sentinel_token_ids method
         additional_special_tokens (`list[str]`, *optional*):
             Additional special tokens used by the tokenizer.
-        vocab (`dict`, *optional*):
+        vocab (`str`, `dict` or `list`, *optional*):
             Custom vocabulary dict. If not provided, a minimal vocabulary is created using the special tokens.
     """
 
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = Unigram
 
     def __init__(
         self,
+        vocab: Optional[Union[str, list[tuple[str, float]]]] = None,
         eos_token="</s>",
         unk_token="<unk>",
         pad_token="<pad>",
         extra_ids=100,
         additional_special_tokens=None,
-        vocab=None,
-        vocab_file=None,
         **kwargs,
     ):
-        self.vocab_file = vocab_file
         self._extra_ids = extra_ids
 
         # Handle extra_ids and additional_special_tokens
@@ -130,10 +129,7 @@ class T5Tokenizer(TokenizersBackend):
 
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
 
-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             eos_token=eos_token,
             unk_token=unk_token,
             pad_token=pad_token,
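After this change, `vocab` is the tokenizer's first positional argument and may be a path, a dict, or a Unigram-style list of `(token, score)` pairs, while the legacy `vocab_file`/`tokenizer_object` plumbing is gone. A hedged usage sketch; the vocabulary entries below are made up, `extra_ids=0` sidesteps sentinel-token bookkeeping, and a real checkpoint may need different setup:

```python
# Hedged usage sketch: building a T5Tokenizer from an in-memory Unigram vocab.
# The (token, score) pairs are invented for illustration only.
from transformers import T5Tokenizer

vocab = [("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0), ("▁hello", -2.3), ("▁world", -2.5)]
tokenizer = T5Tokenizer(vocab=vocab, extra_ids=0)
print(tokenizer.tokenize("hello world"))
```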
@@ -29,7 +29,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_func_from_hub
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -45,7 +45,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
 
 
@@ -108,7 +108,7 @@ class T5GemmaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
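Storing `original_inv_freq` as a cloned, non-persistent buffer (rather than a plain attribute aliasing `inv_freq`) keeps a pristine copy that follows the module across devices and survives in-place updates, which is what dynamic RoPE variants need when they stretch `inv_freq` for long inputs and later reset it. A sketch of that rescale/restore pattern; the `rescale`/`reset` helpers are illustrative:

```python
# Sketch of why `original_inv_freq` is stored as a cloned buffer: dynamic RoPE
# variants overwrite `inv_freq` in place and later need the untouched copy back.
import torch
import torch.nn as nn

class RotarySketch(nn.Module):
    def __init__(self, dim=8, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # A clone, not an alias: in-place edits to inv_freq leave it untouched.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    def rescale(self, factor: float):
        self.inv_freq.div_(factor)  # e.g. dynamic NTK-style stretching

    def reset(self):
        self.inv_freq.copy_(self.original_inv_freq)

rope = RotarySketch()
rope.rescale(2.0)
rope.reset()
assert torch.equal(rope.inv_freq, rope.original_inv_freq)
```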
@@ -147,7 +147,7 @@ class T5GemmaRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
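`maybe_autocast` from `utils.generic` replaces the bare `torch.autocast` context; presumably it degrades to a no-op where autocast is unavailable for the device type. A hedged re-implementation sketch, not the library's actual code:

```python
# Hedged sketch of a `maybe_autocast`-style helper: fall back to a no-op context
# when torch.autocast does not support the device type. Not transformers' code.
import contextlib
import torch

def maybe_autocast_sketch(device_type: str, enabled: bool = True):
    try:
        return torch.autocast(device_type=device_type, enabled=enabled)
    except RuntimeError:  # e.g. unsupported device type
        return contextlib.nullcontext()

with maybe_autocast_sketch(device_type="cpu", enabled=False):  # force float32 math
    freqs = torch.randn(2, 4).float() @ torch.randn(4, 3).float()
```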
@@ -238,6 +238,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class T5GemmaSelfAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -265,7 +266,6 @@ class T5GemmaSelfAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.attn_logit_softcapping = self.config.attn_logit_softcapping
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
 
@@ -315,6 +315,7 @@ class T5GemmaSelfAttention(nn.Module):
         return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class T5GemmaCrossAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -341,7 +342,6 @@ class T5GemmaCrossAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.attn_logit_softcapping = self.config.attn_logit_softcapping
 
         if config.cross_attention_hidden_size is None:
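The per-instance `self.rotary_fn = apply_rotary_pos_emb` assignments disappear in favor of decorating the attention classes with `use_kernelized_func(apply_rotary_pos_emb)`. Presumably the decorator attaches the function at the class level so a kernel registry can rebind it to an optimized implementation in one place. A hedged sketch of that decorator shape, not the actual transformers implementation:

```python
# Hedged sketch of a `use_kernelized_func`-style class decorator. The real
# transformers decorator lives in `integrations` and may differ substantially.
def use_kernelized_func_sketch(func):
    def wrap(cls):
        # Attach the reference implementation once at the class level; a kernel
        # registry could later rebind `cls.rotary_fn` to an optimized version.
        cls.rotary_fn = staticmethod(func)
        return cls
    return wrap

def apply_rotary_pos_emb_ref(q, k, cos, sin):
    ...  # reference implementation elided

@use_kernelized_func_sketch(apply_rotary_pos_emb_ref)
class AttentionSketch:
    pass

assert AttentionSketch.rotary_fn is apply_rotary_pos_emb_ref
```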
@@ -32,9 +32,9 @@ logger = logging.get_logger(__name__)
 
 class T5Gemma2TextConfig(PreTrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`T5Gemma2TextModel`]. It is used to instantiate
-    model according to the specified arguments, defining the model architecture. Instantiating
-    defaults will yield a similar configuration to that of the T5Gemma2Text-7B.
+    This is the configuration class to store the configuration of a [`T5Gemma2TextModel`]. It is used to instantiate the encoder's
+    text model portion of the T5Gemma2 Model according to the specified arguments, defining the model architecture. Instantiating
+    a configuration with the defaults will yield a similar configuration to that of the T5Gemma2Text-7B.
     e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b)
     Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PreTrainedConfig`] for more information.
@@ -99,19 +99,6 @@ class T5Gemma2TextConfig(PreTrainedConfig):
             Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
             a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
             with longer `max_position_embeddings`.
-        use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
-            If True, the model will attend to all text tokens instead of using a causal mask. This does not change
-            behavior for vision tokens.
-
-    ```python
-    >>> from transformers import T5Gemma2TextModel, T5Gemma2TextConfig
-    >>> # Initializing a T5Gemma2Text t5gemma2_text-7b style configuration
-    >>> configuration = T5Gemma2TextConfig()
-    >>> # Initializing a model from the t5gemma2_text-7b style configuration
-    >>> model = T5Gemma2TextModel(configuration)
-    >>> # Accessing the model configuration
-    >>> configuration = model.config
-    ```
     """
 
     model_type = "t5gemma2_text"
@@ -158,7 +145,6 @@ class T5Gemma2TextConfig(PreTrainedConfig):
         final_logit_softcapping: Optional[float] = None,
         attn_logit_softcapping: Optional[float] = None,
         rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
-        use_bidirectional_attention: Optional[bool] = False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -181,10 +167,6 @@ class T5Gemma2TextConfig(PreTrainedConfig):
         self.attn_logit_softcapping = attn_logit_softcapping
         self.layer_types = layer_types
 
-        self.use_bidirectional_attention = use_bidirectional_attention
-        if use_bidirectional_attention:
-            self.sliding_window = (self.sliding_window // 2) + 1  # due to fa we set exclusive bounds
-
         # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
         self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
 
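The removed `use_bidirectional_attention` branch halved the sliding window, with the original comment attributing the `+ 1` to flash-attention's exclusive bounds (hedged reading: a bidirectional window is split across both sides of a token, and the exclusive bound needs one extra position). The arithmetic of the removed formula, for reference:

```python
# Arithmetic of the removed bidirectional sliding-window adjustment.
sliding_window = 512
per_side_exclusive = (sliding_window // 2) + 1
print(per_side_exclusive)  # 257
```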
@@ -326,9 +308,9 @@ class T5Gemma2EncoderConfig(PreTrainedConfig):
 
 class T5Gemma2DecoderConfig(PreTrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`T5Gemma2DecoderModel`]. It is used to instantiate
-    model according to the specified arguments, defining the model architecture. Instantiating
-    defaults will yield a similar configuration to that of the T5Gemma2Decoder-7B.
+    This is the configuration class to store the configuration of a [`T5Gemma2DecoderModel`]. It is used to instantiate the decoder
+    text model portion of the T5Gemma2 Model according to the specified arguments, defining the model architecture. Instantiating
+    a configuration with the defaults will yield a similar configuration to that of the T5Gemma2Decoder-7B.
     e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b)
     Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PreTrainedConfig`] for more information.
@@ -393,19 +375,6 @@ class T5Gemma2DecoderConfig(PreTrainedConfig):
             Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
             a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
             with longer `max_position_embeddings`.
-        use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
-            If True, the model will attend to all text tokens instead of using a causal mask. This does not change
-            behavior for vision tokens.
-
-    ```python
-    >>> from transformers import T5Gemma2DecoderModel, T5Gemma2DecoderConfig
-    >>> # Initializing a T5Gemma2Decoder t5gemma2_text-7b style configuration
-    >>> configuration = T5Gemma2DecoderConfig()
-    >>> # Initializing a model from the t5gemma2_text-7b style configuration
-    >>> model = T5Gemma2DecoderModel(configuration)
-    >>> # Accessing the model configuration
-    >>> configuration = model.config
-    ```
     """
 
     model_type = "t5gemma2_decoder"
@@ -452,7 +421,6 @@ class T5Gemma2DecoderConfig(PreTrainedConfig):
         final_logit_softcapping: Optional[float] = None,
         attn_logit_softcapping: Optional[float] = None,
         rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
-        use_bidirectional_attention: Optional[bool] = False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -475,10 +443,6 @@ class T5Gemma2DecoderConfig(PreTrainedConfig):
         self.attn_logit_softcapping = attn_logit_softcapping
         self.layer_types = layer_types
 
-        self.use_bidirectional_attention = use_bidirectional_attention
-        if use_bidirectional_attention:
-            self.sliding_window = (self.sliding_window // 2) + 1  # due to fa we set exclusive bounds
-
         # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
         self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
 
@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
 from ...generation import GenerationConfig, GenerationMixin, GenerationMode
-from ...integrations import use_kernel_func_from_hub
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_bidirectional_mask, create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -46,7 +46,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from ..auto import AutoModel
 from .configuration_t5gemma2 import T5Gemma2Config, T5Gemma2DecoderConfig, T5Gemma2EncoderConfig, T5Gemma2TextConfig
 
@@ -113,7 +113,7 @@ class T5Gemma2RotaryEmbedding(nn.Module):
             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
             curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
             self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-
+            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
             setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
 
     @staticmethod
@@ -162,7 +162,7 @@ class T5Gemma2RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * attention_scaling
@@ -253,6 +253,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class T5Gemma2SelfAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -265,7 +266,7 @@ class T5Gemma2SelfAttention(nn.Module):
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = config.query_pre_attn_scalar**-0.5
         self.attention_dropout = self.config.attention_dropout
-        self.is_causal =
+        self.is_causal = False  # Only used by the encoder
 
         self.q_proj = nn.Linear(
             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
@@ -279,7 +280,6 @@ class T5Gemma2SelfAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.attn_logit_softcapping = self.config.attn_logit_softcapping
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
         self.is_sliding = self.layer_type == "sliding_attention"
@@ -294,7 +294,7 @@ class T5Gemma2SelfAttention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -335,6 +335,7 @@ class T5Gemma2SelfAttention(nn.Module):
         return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class T5Gemma2MergedAttention(nn.Module):
     """Merged self-attention and cross-attention for decoder."""
 
@@ -347,7 +348,7 @@ class T5Gemma2MergedAttention(nn.Module):
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = config.query_pre_attn_scalar**-0.5
         self.attention_dropout = self.config.attention_dropout
-        self.is_causal =
+        self.is_causal = False  # Fused causal and encoder mask
 
         self.q_proj = nn.Linear(
             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
@@ -361,7 +362,6 @@ class T5Gemma2MergedAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.attn_logit_softcapping = self.config.attn_logit_softcapping
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
         self.is_sliding = self.layer_type == "sliding_attention"
@@ -446,7 +446,6 @@ class T5Gemma2MergedAttention(nn.Module):
             merged_attention_mask,
             dropout=self.attention_dropout if self.training else 0.0,
             scaling=self.scaling,
-            is_causal=False,
             **kwargs,
         )
 
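`is_causal=False` is no longer passed at the attention call site; the flag now lives on the module (set in `__init__`, see the hunks above), where attention backends can read it. A hedged sketch of that shape using SDPA; the interface names here are illustrative:

```python
# Hedged sketch: attention backends read `module.is_causal` instead of taking a
# per-call `is_causal` argument. Interface names are illustrative.
import torch
import torch.nn.functional as F

class MergedAttentionSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.is_causal = False  # the mask already fuses causal + encoder parts

    def forward(self, q, k, v, attn_mask=None):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=self.is_causal)

attn = MergedAttentionSketch()
q = k = v = torch.randn(1, 2, 5, 4)
out = attn(q, k, v)
```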
@@ -649,6 +648,7 @@ class T5Gemma2TextScaledWordEmbedding(nn.Embedding):
         eoi_token_index: int = 256_000,
     ):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.scalar_embed_scale = embed_scale
         self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
         self.eoi_token_index = eoi_token_index
         self.eoi_embedding = nn.Parameter(torch.zeros(self.embedding_dim))
@@ -700,6 +700,7 @@ class T5Gemma2PreTrainedModel(PreTrainedModel):
             init.zeros_(module.mm_input_projection_weight)
         elif isinstance(module, T5Gemma2TextScaledWordEmbedding):
             init.zeros_(module.eoi_embedding)
+            init.constant_(module.embed_scale, module.scalar_embed_scale)
         elif isinstance(module, T5Gemma2ClassificationHead):
             scale = module.out_proj.weight.shape[0] ** -0.5
             init.normal_(module.out_proj.weight, mean=0.0, std=self.config.initializer_range * scale)
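`embed_scale` is a non-persistent buffer, so it never appears in checkpoints; keeping the plain Python scalar `scalar_embed_scale` alongside it lets `_init_weights` re-materialize the buffer's value (the `init.constant_` line above). A sketch of the mechanism:

```python
# Sketch: a non-persistent buffer is absent from checkpoints, so the plain
# scalar is kept around to re-materialize it during weight init.
import torch
import torch.nn as nn

class ScaledEmbeddingSketch(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.scalar_embed_scale = embed_scale
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

emb = ScaledEmbeddingSketch(10, 4, embed_scale=2.0)
assert "embed_scale" not in emb.state_dict()  # non-persistent -> not checkpointed
torch.nn.init.constant_(emb.embed_scale, emb.scalar_embed_scale)  # re-init path
```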
@@ -708,6 +709,14 @@ class T5Gemma2PreTrainedModel(PreTrainedModel):
         # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
         elif "RMSNorm" in module.__class__.__name__:
             init.zeros_(module.weight)
+        elif isinstance(module, T5Gemma2RotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
 
     def prepare_decoder_input_ids_from_labels(self, input_ids):
         """