transformers 5.0.0rc1-py3-none-any.whl → 5.0.0rc2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
The diff hunks below are from transformers/models/t5gemma2/modular_t5gemma2.py.

@@ -22,7 +22,7 @@ import torch.nn as nn
 
 from ... import initialization as init
 from ...cache_utils import DynamicCache, EncoderDecoderCache, StaticCache
-from ...configuration_utils import PreTrainedConfig
+from ...configuration_utils import PreTrainedConfig, layer_type_validation
 from ...generation import GenerationConfig, GenerationMixin, GenerationMode
 from ...masking_utils import create_bidirectional_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -34,6 +34,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import (
@@ -70,9 +71,146 @@ from ..t5gemma.modeling_t5gemma import (
 logger = logging.get_logger(__name__)
 
 
-class T5Gemma2TextConfig(Gemma3TextConfig):
+class T5Gemma2TextConfig(Gemma3TextConfig, PreTrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`T5Gemma2TextModel`]. It is used to instantiate the encoder's
+    text model portion of the T5Gemma2 Model according to the specified arguments, defining the model architecture. Instantiating
+    a configuration with the defaults will yield a similar configuration to that of the T5Gemma2Text-7B.
+    e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b)
+    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PreTrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 262208):
+            Vocabulary size of the T5Gemma2Text model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`T5Gemma2TextModel`]
+        hidden_size (`int`, *optional*, defaults to 2304):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 9216):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 26):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to 4):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        head_dim (`int`, *optional*, defaults to 256):
+            The attention head dimension.
+        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
+        max_position_embeddings (`int`, *optional*, defaults to 131072):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        eos_token_id (`int`, *optional*, defaults to 1):
+            End of stream token id.
+        bos_token_id (`int`, *optional*, defaults to 2):
+            Beginning of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether to tie weight embeddings
+        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
+            Scaling factor used on the attention scores
+        sliding_window (`int`, *optional*, defaults to 4096):
+            In T5Gemma2Text, every other layer uses sliding window attention. This is the size of the sliding window.
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.
+        final_logit_softcapping (`float`, *optional*):
+            Scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*):
+            Scaling factor when applying tanh softcapping on the attention scores.
+        rope_parameters (`RopeParameters`, *optional*):
+            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+            with longer `max_position_embeddings`.
+    """
+
     model_type = "t5gemma2_text"
 
+    def __init__(
+        self,
+        vocab_size: Optional[int] = 262_208,
+        hidden_size: Optional[int] = 2304,
+        intermediate_size: Optional[int] = 9216,
+        num_hidden_layers: Optional[int] = 26,
+        num_attention_heads: Optional[int] = 8,
+        num_key_value_heads: Optional[int] = 4,
+        head_dim: Optional[int] = 256,
+        hidden_activation: Optional[str] = "gelu_pytorch_tanh",
+        max_position_embeddings: Optional[int] = 131_072,
+        initializer_range: Optional[float] = 0.02,
+        rms_norm_eps: Optional[int] = 1e-6,
+        use_cache: Optional[bool] = True,
+        pad_token_id: Optional[int] = 0,
+        eos_token_id: Optional[int] = 1,
+        bos_token_id: Optional[int] = 2,
+        tie_word_embeddings: Optional[bool] = True,
+        attention_bias: Optional[bool] = False,
+        attention_dropout: Optional[float] = 0.0,
+        query_pre_attn_scalar: Optional[int] = 256,
+        sliding_window: Optional[int] = 4096,
+        layer_types: Optional[list[str]] = None,
+        final_logit_softcapping: Optional[float] = None,
+        attn_logit_softcapping: Optional[float] = None,
+        rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.head_dim = head_dim
+        self.num_key_value_heads = num_key_value_heads
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.hidden_activation = hidden_activation
+        self.query_pre_attn_scalar = query_pre_attn_scalar
+        self.sliding_window = sliding_window
+        self.final_logit_softcapping = final_logit_softcapping
+        self.attn_logit_softcapping = attn_logit_softcapping
+        self.layer_types = layer_types
+
+        # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
+        self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
+
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+        self.rope_parameters = rope_parameters
+        PreTrainedConfig.__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
 
 class T5Gemma2EncoderConfig(Gemma3Config):
     model_type = "t5gemma2_encoder"
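The core of the new `T5Gemma2TextConfig.__init__` above is the default `layer_types` schedule, driven by the backward-compatible `sliding_window_pattern` kwarg (default 6): every sixth layer uses full attention and all other layers use sliding-window attention. Below is a minimal, self-contained sketch of that default, simply re-running the list comprehension from the hunk (plain Python, no transformers import needed):

num_hidden_layers = 26          # default from the config above
sliding_window_pattern = 6      # BC default pulled from kwargs in the diff

layer_types = [
    "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
    for i in range(num_hidden_layers)
]

# Layers 6, 12, 18 and 24 (1-indexed) use full attention, the rest use sliding-window attention.
print(layer_types[:6])
# ['sliding_attention', 'sliding_attention', 'sliding_attention',
#  'sliding_attention', 'sliding_attention', 'full_attention']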
@@ -83,9 +221,146 @@ class T5Gemma2EncoderConfig(Gemma3Config):
|
|
|
83
221
|
}
|
|
84
222
|
|
|
85
223
|
|
|
86
|
-
class T5Gemma2DecoderConfig(Gemma3TextConfig):
|
|
224
|
+
class T5Gemma2DecoderConfig(Gemma3TextConfig, PreTrainedConfig):
|
|
225
|
+
r"""
|
|
226
|
+
This is the configuration class to store the configuration of a [`T5Gemma2DecoderModel`]. It is used to instantiate the decoder
|
|
227
|
+
text model portion of the T5Gemma2 Model according to the specified arguments, defining the model architecture. Instantiating
|
|
228
|
+
a configuration with the defaults will yield a similar configuration to that of the T5Gemma2Decoder-7B.
|
|
229
|
+
e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b)
|
|
230
|
+
Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
|
|
231
|
+
documentation from [`PreTrainedConfig`] for more information.
|
|
232
|
+
|
|
233
|
+
Args:
|
|
234
|
+
vocab_size (`int`, *optional*, defaults to 262208):
|
|
235
|
+
Vocabulary size of the T5Gemma2Decoder model. Defines the number of different tokens that can be represented by the
|
|
236
|
+
`inputs_ids` passed when calling [`T5Gemma2DecoderModel`]
|
|
237
|
+
hidden_size (`int`, *optional*, defaults to 2304):
|
|
238
|
+
Dimension of the hidden representations.
|
|
239
|
+
intermediate_size (`int`, *optional*, defaults to 9216):
|
|
240
|
+
Dimension of the MLP representations.
|
|
241
|
+
num_hidden_layers (`int`, *optional*, defaults to 26):
|
|
242
|
+
Number of hidden layers in the Transformer decoder.
|
|
243
|
+
num_attention_heads (`int`, *optional*, defaults to 8):
|
|
244
|
+
Number of attention heads for each attention layer in the Transformer decoder.
|
|
245
|
+
num_key_value_heads (`int`, *optional*, defaults to 4):
|
|
246
|
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
|
247
|
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
|
248
|
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
|
249
|
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
|
250
|
+
by meanpooling all the original heads within that group. For more details, check out [this
|
|
251
|
+
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
|
|
252
|
+
`num_attention_heads`.
|
|
253
|
+
head_dim (`int`, *optional*, defaults to 256):
|
|
254
|
+
The attention head dimension.
|
|
255
|
+
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
|
256
|
+
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
|
|
257
|
+
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
|
|
258
|
+
max_position_embeddings (`int`, *optional*, defaults to 131072):
|
|
259
|
+
The maximum sequence length that this model might ever be used with.
|
|
260
|
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
|
261
|
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
|
262
|
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
|
263
|
+
The epsilon used by the rms normalization layers.
|
|
264
|
+
use_cache (`bool`, *optional*, defaults to `True`):
|
|
265
|
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
|
266
|
+
relevant if `config.is_decoder=True`.
|
|
267
|
+
pad_token_id (`int`, *optional*, defaults to 0):
|
|
268
|
+
Padding token id.
|
|
269
|
+
eos_token_id (`int`, *optional*, defaults to 1):
|
|
270
|
+
End of stream token id.
|
|
271
|
+
bos_token_id (`int`, *optional*, defaults to 2):
|
|
272
|
+
Beginning of stream token id.
|
|
273
|
+
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
|
274
|
+
Whether to tie weight embeddings
|
|
275
|
+
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
|
276
|
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
|
277
|
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
|
278
|
+
The dropout ratio for the attention probabilities.
|
|
279
|
+
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
|
|
280
|
+
Scaling factor used on the attention scores
|
|
281
|
+
sliding_window (`int`, *optional*, defaults to 4096):
|
|
282
|
+
In T5Gemma2Decoder, every other layer uses sliding window attention. This is the size of the sliding window.
|
|
283
|
+
layer_types (`list`, *optional*):
|
|
284
|
+
Attention pattern for each layer.
|
|
285
|
+
final_logit_softcapping (`float`, *optional*):
|
|
286
|
+
Scaling factor when applying tanh softcapping on the logits.
|
|
287
|
+
attn_logit_softcapping (`float`, *optional*):
|
|
288
|
+
Scaling factor when applying tanh softcapping on the attention scores.
|
|
289
|
+
rope_parameters (`RopeParameters`, *optional*):
|
|
290
|
+
Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
|
|
291
|
+
a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
|
|
292
|
+
with longer `max_position_embeddings`.
|
|
293
|
+
"""
|
|
294
|
+
|
|
     model_type = "t5gemma2_decoder"

+    def __init__(
+        self,
+        vocab_size: Optional[int] = 262_208,
+        hidden_size: Optional[int] = 2304,
+        intermediate_size: Optional[int] = 9216,
+        num_hidden_layers: Optional[int] = 26,
+        num_attention_heads: Optional[int] = 8,
+        num_key_value_heads: Optional[int] = 4,
+        head_dim: Optional[int] = 256,
+        hidden_activation: Optional[str] = "gelu_pytorch_tanh",
+        max_position_embeddings: Optional[int] = 131_072,
+        initializer_range: Optional[float] = 0.02,
+        rms_norm_eps: Optional[int] = 1e-6,
+        use_cache: Optional[bool] = True,
+        pad_token_id: Optional[int] = 0,
+        eos_token_id: Optional[int] = 1,
+        bos_token_id: Optional[int] = 2,
+        tie_word_embeddings: Optional[bool] = True,
+        attention_bias: Optional[bool] = False,
+        attention_dropout: Optional[float] = 0.0,
+        query_pre_attn_scalar: Optional[int] = 256,
+        sliding_window: Optional[int] = 4096,
+        layer_types: Optional[list[str]] = None,
+        final_logit_softcapping: Optional[float] = None,
+        attn_logit_softcapping: Optional[float] = None,
+        rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.head_dim = head_dim
+        self.num_key_value_heads = num_key_value_heads
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.hidden_activation = hidden_activation
+        self.query_pre_attn_scalar = query_pre_attn_scalar
+        self.sliding_window = sliding_window
+        self.final_logit_softcapping = final_logit_softcapping
+        self.attn_logit_softcapping = attn_logit_softcapping
+        self.layer_types = layer_types
+
+        # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
+        self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
+
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+        self.rope_parameters = rope_parameters
+        PreTrainedConfig.__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
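Illustration (not part of the diff): a minimal sketch of the layer pattern the new `__init__` derives when `layer_types` is not passed, using the defaults shown above (`num_hidden_layers=26`, backward-compatible `sliding_window_pattern=6`).

    # Sketch only: reproduce the layer_types default built in T5Gemma2DecoderConfig.__init__.
    num_hidden_layers = 26           # default from the diff
    sliding_window_pattern = 6       # BC default read from kwargs in the diff

    layer_types = [
        "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
        for i in range(num_hidden_layers)
    ]

    # Every 6th layer uses full attention, the rest use the 4096-token sliding window.
    assert [i for i, t in enumerate(layer_types) if t == "full_attention"] == [5, 11, 17, 23]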

 class T5Gemma2Config(PreTrainedConfig):
     r"""

@@ -257,6 +532,7 @@ class T5Gemma2RotaryEmbedding(Gemma3RotaryEmbedding):
 class T5Gemma2SelfAttention(Gemma3Attention):
     def __init__(self, config: T5Gemma2TextConfig, layer_idx: int):
         super().__init__(config, layer_idx)
+        self.is_causal = False  # Only used by the encoder


 class T5Gemma2MergedAttention(Gemma3Attention):

@@ -264,6 +540,7 @@ class T5Gemma2MergedAttention(Gemma3Attention):

     def __init__(self, config: T5Gemma2TextConfig, layer_idx: int):
         super().__init__(config, layer_idx)
+        self.is_causal = False  # Fused causal and encoder mask

     def forward(
         self,

@@ -342,7 +619,6 @@ class T5Gemma2MergedAttention(Gemma3Attention):
             merged_attention_mask,
             dropout=self.attention_dropout if self.training else 0.0,
             scaling=self.scaling,
-            is_causal=False,
             **kwargs,
         )

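Illustration (not part of the diff; the mask layout below is an assumption, not taken from the modeling code): why `is_causal=False` is needed once the causal self-attention block is fused with an always-visible encoder block, as the comment on `T5Gemma2MergedAttention` suggests.

    import torch

    # Sketch only: a causal block concatenated with a fully visible cross-attention block
    # is no longer a purely causal mask, so the kernel's is_causal fast path cannot be used
    # and the merged mask has to be passed explicitly.
    tgt_len, src_len = 4, 3
    causal_block = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool))
    encoder_block = torch.ones(tgt_len, src_len, dtype=torch.bool)  # decoder tokens see every encoder token
    merged_attention_mask = torch.cat([causal_block, encoder_block], dim=-1)
    print(merged_attention_mask.int())  # shape (4, 7): lower-triangular left block, all-ones right block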
@@ -498,6 +774,7 @@ class T5Gemma2PreTrainedModel(Gemma3PreTrainedModel):
             init.zeros_(module.mm_input_projection_weight)
         elif isinstance(module, T5Gemma2TextScaledWordEmbedding):
             init.zeros_(module.eoi_embedding)
+            init.constant_(module.embed_scale, module.scalar_embed_scale)
         elif isinstance(module, T5Gemma2ClassificationHead):
             scale = module.out_proj.weight.shape[0] ** -0.5
             init.normal_(module.out_proj.weight, mean=0.0, std=self.config.initializer_range * scale)

@@ -506,6 +783,14 @@ class T5Gemma2PreTrainedModel(Gemma3PreTrainedModel):
         # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
         elif "RMSNorm" in module.__class__.__name__:
             init.zeros_(module.weight)
+        elif isinstance(module, T5Gemma2RotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)

     def prepare_decoder_input_ids_from_labels(self, input_ids):
         """
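Illustration (not part of the diff): the shape of the inverse-frequency buffers this branch re-creates, assuming the standard default RoPE parameterization (the actual helper lives in `modeling_rope_utils`).

    import torch

    # Sketch only: the default RoPE inverse frequencies a rope init function typically
    # returns -- one frequency per pair of head dimensions.
    def default_inv_freq(rope_theta: float = 10_000.0, head_dim: int = 256) -> torch.Tensor:
        return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

    inv_freq = default_inv_freq()
    print(inv_freq.shape)  # torch.Size([128]); copied into both {layer_type}_inv_freq buffers above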
@@ -37,7 +37,7 @@ class TableTransformerConfig(PreTrainedConfig):
     use_timm_backbone (`bool`, *optional*, defaults to `True`):
         Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
         API.
-    backbone_config (`PreTrainedConfig
+    backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `ResNetConfig()`):
         The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
         case it will default to `ResNetConfig()`.
     num_channels (`int`, *optional*, defaults to 3):

@@ -702,7 +702,7 @@ class TableTransformerPreTrainedModel(PreTrainedModel):
         if isinstance(module, TableTransformerLearnedPositionEmbedding):
             init.uniform_(module.row_embeddings.weight)
             init.uniform_(module.column_embeddings.weight)
-        if isinstance(module, (nn.Linear, nn.Conv2d
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
             init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
                 init.zeros_(module.bias)
@@ -137,7 +137,6 @@ class TextNetImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

@@ -144,6 +144,7 @@ class TimesFmPositionalEmbedding(nn.Module):
         super().__init__()
         min_timescale = config.min_timescale
         max_timescale = config.max_timescale
+        self.min_timescale, self.max_timescale = min_timescale, max_timescale
         self.embedding_dims = config.hidden_size

         num_timescales = self.embedding_dims // 2

@@ -313,6 +314,17 @@ class TimesFmPreTrainedModel(PreTrainedModel):
         if isinstance(module, TimesFmAttention):
             # Initialize scaling parameter
             init.ones_(module.scaling)
+        elif isinstance(module, TimesFmPositionalEmbedding):
+            num_timescales = module.embedding_dims // 2
+            max_timescale, min_timescale = module.max_timescale, module.min_timescale
+            log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
+                num_timescales - 1, 1
+            )
+            init.copy_(
+                module.inv_timescales,
+                min_timescale
+                * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+            )


     @auto_docstring
@@ -123,6 +123,7 @@ class TimesFmPositionalEmbedding(nn.Module):
         super().__init__()
         min_timescale = config.min_timescale
         max_timescale = config.max_timescale
+        self.min_timescale, self.max_timescale = min_timescale, max_timescale
         self.embedding_dims = config.hidden_size

         num_timescales = self.embedding_dims // 2

@@ -269,6 +270,17 @@ class TimesFmPreTrainedModel(PreTrainedModel):
         if isinstance(module, TimesFmAttention):
             # Initialize scaling parameter
             init.ones_(module.scaling)
+        elif isinstance(module, TimesFmPositionalEmbedding):
+            num_timescales = module.embedding_dims // 2
+            max_timescale, min_timescale = module.max_timescale, module.min_timescale
+            log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
+                num_timescales - 1, 1
+            )
+            init.copy_(
+                module.inv_timescales,
+                min_timescale
+                * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+            )


     @auto_docstring
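Illustration (not part of the diff): a standalone version of the `inv_timescales` buffer that the new `_init_weights` branch fills in; the `min_timescale`/`max_timescale` defaults here are assumptions, in the model they come from the config.

    import math
    import torch

    # Sketch only: same computation as the TimesFmPositionalEmbedding branch above.
    def inv_timescales(hidden_size: int, min_timescale: float = 1.0, max_timescale: float = 10_000.0) -> torch.Tensor:
        num_timescales = hidden_size // 2
        log_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
        return min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_increment)

    print(inv_timescales(hidden_size=1280).shape)  # torch.Size([640])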
@@ -16,10 +16,12 @@
 from typing import Optional, Union

 import torch
+from torch import Tensor, nn

+from ... import initialization as init
 from ...modeling_outputs import BackboneOutput
 from ...modeling_utils import PreTrainedModel
-from ...utils import is_timm_available,
+from ...utils import is_timm_available, requires_backends
 from ...utils.backbone_utils import BackboneMixin
 from .configuration_timm_backbone import TimmBackboneConfig

@@ -28,10 +30,6 @@ if is_timm_available():
     import timm


-if is_torch_available():
-    from torch import Tensor
-
-
 class TimmBackbone(PreTrainedModel, BackboneMixin):
     """
     Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the

@@ -84,10 +82,11 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
         self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
         super()._init_backbone(config)

+        self.post_init()
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         requires_backends(cls, ["vision", "timm"])
-        from ...models.timm_backbone import TimmBackboneConfig

         config = kwargs.pop("config", TimmBackboneConfig())

@@ -116,9 +115,14 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):

     @torch.no_grad()
     def _init_weights(self, module):
-        """
-
-        ""
+        """We need to at least re-init the non-persistent buffers if the model was initialized on meta device (we
+        assume weights and persistent buffers will be part of checkpoint as we have no way to control timm inits)"""
+        if hasattr(module, "init_non_persistent_buffers"):
+            module.init_non_persistent_buffers()
+        elif isinstance(module, nn.BatchNorm2d) and getattr(module, "running_mean", None) is not None:
+            init.zeros_(module.running_mean)
+            init.ones_(module.running_var)
+            init.zeros_(module.num_batches_tracked)

     def forward(
         self,
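Illustration (not part of the diff): what the new `nn.BatchNorm2d` branch amounts to -- resetting the running statistics that a meta-device initialization would otherwise leave unmaterialized.

    import torch
    from torch import nn

    # Sketch only: reset BatchNorm running stats the same way the _init_weights branch does.
    bn = nn.BatchNorm2d(8)
    bn.running_mean.fill_(3.0)  # pretend stale values survived loading
    bn.running_var.fill_(7.0)

    with torch.no_grad():
        bn.running_mean.zero_()
        bn.running_var.fill_(1.0)
        bn.num_batches_tracked.zero_()

    print(bn.running_mean.abs().sum().item(), bn.running_var.mean().item())  # 0.0 1.0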
@@ -81,6 +81,9 @@ class TimmWrapperConfig(PreTrainedConfig):

     @classmethod
     def from_dict(cls, config_dict: dict[str, Any], **kwargs):
+        # Create a copy to avoid mutating the original dict
+        config_dict = config_dict.copy()
+
         label_names = config_dict.get("label_names")
         is_custom_model = "num_labels" in kwargs or "id2label" in kwargs

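Illustration (not part of the diff): why the added `config_dict.copy()` matters -- any key that `from_dict` pops or rewrites further down would otherwise leak back into the caller's dictionary (the mutation below is hypothetical).

    # Sketch only: the caller's dict stays intact once from_dict works on a copy.
    def from_dict_like(config_dict, copy_first: bool):
        if copy_first:
            config_dict = config_dict.copy()
        config_dict.pop("label_names", None)  # hypothetical downstream mutation
        return config_dict

    original = {"label_names": ["cat", "dog"], "num_classes": 2}
    from_dict_like(original, copy_first=True)
    assert "label_names" in original      # unchanged, as with the new code
    from_dict_like(original, copy_first=False)
    assert "label_names" not in original  # without the copy, the input would be mutated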
@@ -84,16 +84,13 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     input_modalities = ("image",)
     config: TimmWrapperConfig
-
+    # add WA here as `timm` does not support model parallelism
+    _no_split_modules = ["TimmWrapperModel"]
     model_tags = ["timm"]

     # used in Trainer to avoid passing `loss_kwargs` to model forward
     accepts_loss_kwargs = False

-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["vision", "timm"])
-        super().__init__(*args, **kwargs)
-
     def post_init(self):
         self.supports_gradient_checkpointing = self._timm_model_supports_gradient_checkpointing()
         super().post_init()

@@ -113,10 +110,17 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
         Since model architectures may vary, we assume only the classifier requires
         initialization, while all other weights should be loaded from the checkpoint.
         """
-        if isinstance(module,
+        if isinstance(module, nn.Linear):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+        # Also, reinit all non-persistent buffers if any!
+        if hasattr(module, "init_non_persistent_buffers"):
+            module.init_non_persistent_buffers()
+        elif isinstance(module, nn.BatchNorm2d) and getattr(module, "running_mean", None) is not None:
+            init.zeros_(module.running_mean)
+            init.ones_(module.running_var)
+            init.zeros_(module.num_batches_tracked)

     def _timm_model_supports_gradient_checkpointing(self):
         """

@@ -136,6 +140,13 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
     def _set_gradient_checkpointing(self, enable: bool = True, *args, **kwargs):
         self.timm_model.set_grad_checkpointing(enable)

+    def get_input_embeddings(self):
+        # TIMM backbones operate directly on images and do not expose token embeddings.
+        return None
+
+    def set_input_embeddings(self, value):
+        raise NotImplementedError("TimmWrapper models do not own token embeddings and cannot set them.")
+

 class TimmWrapperModel(TimmWrapperPreTrainedModel):
     """

@@ -143,6 +154,7 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
     """

     def __init__(self, config: TimmWrapperConfig):
+        requires_backends(self, ["vision", "timm"])
         super().__init__(config)
         # using num_classes=0 to avoid creating classification head
         extra_init_kwargs = config.model_args or {}

@@ -150,13 +162,6 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
         self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs)
         self.post_init()

-    def get_input_embeddings(self):
-        # Vision backbones from timm do not expose token embeddings, so there is nothing to return.
-        return None
-
-    def set_input_embeddings(self, value):
-        raise NotImplementedError("TimmWrapperModel does not own token embeddings and cannot set them.")
-
     @auto_docstring
     def forward(
         self,

@@ -265,6 +270,7 @@ class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel):
     """

     def __init__(self, config: TimmWrapperConfig):
+        requires_backends(self, ["vision", "timm"])
         super().__init__(config)

         if config.num_labels == 0:
@@ -89,7 +89,6 @@ class TrOCRSinusoidalPositionalEmbedding(nn.Module):
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.weights = self.get_embedding(num_positions, embedding_dim, padding_idx)
-        self.register_buffer("_float_tensor", torch.FloatTensor(1))

     @staticmethod
     def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):

@@ -123,7 +122,6 @@ class TrOCRSinusoidalPositionalEmbedding(nn.Module):
         if self.weights is None or max_pos > self.weights.size(0):
             # recompute/expand embeddings if needed
             self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
-            self.weights = self.weights.to(self._float_tensor)

         x = self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()


@@ -636,6 +634,7 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.decoder = TrOCRDecoder(config)
+        self.post_init()

     def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)
@@ -35,7 +35,7 @@ class TvpConfig(PreTrainedConfig):


     Args:
-        backbone_config (`PreTrainedConfig
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `ResNetConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this

@@ -68,6 +68,8 @@ class TvpConfig(PreTrainedConfig):
         vocab_size (`int`, *optional*, defaults to 30522):
             Vocabulary size of the Tvp text model. Defines the number of different tokens that can be represented by
             the `inputs_ids` passed when calling [`TvpModel`].
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`TvpModel`].
         hidden_size (`int`, *optional*, defaults to 768):
             Dimensionality of the encoder layers.
         intermediate_size (`int`, *optional*, defaults to 3072):

@@ -114,6 +116,7 @@ class TvpConfig(PreTrainedConfig):
         max_img_size=448,
         num_frames=48,
         vocab_size=30522,
+        type_vocab_size=2,
         hidden_size=768,
         intermediate_size=3072,
         num_hidden_layers=12,

@@ -157,6 +160,7 @@ class TvpConfig(PreTrainedConfig):
         self.max_img_size = max_img_size
         self.num_frames = num_frames
         self.vocab_size = vocab_size
+        self.type_vocab_size = type_vocab_size
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
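Illustration (not part of the diff; assumes the 5.0.0rc2 wheel is installed): the new `type_vocab_size` field behaves like any other serialized config attribute.

    from transformers import TvpConfig

    # Sketch only: the field is accepted by the constructor and round-trips through to_dict().
    config = TvpConfig(type_vocab_size=2)
    assert config.type_vocab_size == 2
    assert config.to_dict()["type_vocab_size"] == 2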