transformers 5.0.0rc0-py3-none-any.whl → 5.0.0rc2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/gemma/tokenization_gemma.py

```diff
@@ -12,12 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Optional, Union
 
 from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
 from tokenizers.models import BPE
 
-from ...tokenization_utils_base import generate_merges
 from ...tokenization_utils_tokenizers import TokenizersBackend
 from ...utils import logging
 
@@ -30,7 +29,7 @@ class GemmaTokenizer(TokenizersBackend):
    """
    Construct a fast Gemma tokenizer (backed by HuggingFace's tokenizers library).

-    This tokenizer uses a
+    This tokenizer uses a BPE model with byte fallback, no prefix space, and a normalizer that replaces
    spaces with "▁".

    Args:
@@ -50,48 +49,37 @@ class GemmaTokenizer(TokenizersBackend):
            Whether or not to add a `bos_token` at the start of sequences.
        add_eos_token (`bool`, optional, defaults to False):
            Whether or not to add an `eos_token` at the end of sequences.
-        vocab (`dict`, optional):
+        vocab (`str` or `dict[str, int]`, optional):
            Custom vocabulary dict. If not provided, a minimal vocabulary is created using the special tokens.
    """

    vocab_files_names = VOCAB_FILES_NAMES
-    slow_tokenizer_class = None
    padding_side = "left"
    model_input_names = ["input_ids", "attention_mask"]
+    model = BPE

    def __init__(
        self,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
        unk_token: str = "<unk>",
        bos_token: str = "<bos>",
        eos_token: str = "<eos>",
        pad_token: str = "<pad>",
        mask_token: str = "<mask>",
-        add_bos_token: bool = True,
-        add_eos_token: bool = False,
-        vocab: Optional[dict] = None,
-        merges: Optional[list[tuple[str, str]]] = None,
        **kwargs,
    ):
-
-
-
-        special_tokens = {str(pad_token), str(eos_token), str(bos_token), str(unk_token)}
-
-        if vocab is not None:
-            self._vocab = (
-                {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
-            )
-        else:
-            self._vocab = {
+        if vocab is None:
+            vocab = {
                str(pad_token): 0,
                str(eos_token): 1,
                str(bos_token): 2,
                str(unk_token): 3,
                str(mask_token): 4,
            }
+        self._vocab = vocab
+        self._merges = merges or []

-        filtered_vocab = {t: i for t, i in self._vocab.items() if t not in special_tokens}
-        self._merges = merges if merges is not None else generate_merges(filtered_vocab)
        self._tokenizer = Tokenizer(
            BPE(
                vocab=self._vocab,
@@ -108,17 +96,12 @@ class GemmaTokenizer(TokenizersBackend):
        )
        self._tokenizer.normalizer = normalizers.Replace(" ", "▁")
        self._tokenizer.pre_tokenizer = pre_tokenizers.Split(" ", "merged_with_previous")
-        tokenizer_object = self._tokenizer
-
        super().__init__(
-            tokenizer_object=tokenizer_object,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            mask_token=mask_token,
-            add_bos_token=add_bos_token,
-            add_eos_token=add_eos_token,
            **kwargs,
        )
```
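The rc2 signature above moves `vocab` and `merges` to the front of `__init__`, drops the `add_bos_token`/`add_eos_token` arguments along with the slow-tokenizer plumbing, and defaults `merges` to an empty list instead of calling `generate_merges`. A minimal construction sketch, assuming only what the hunk shows (the token ids are the defaults the constructor builds when no vocab is passed, not the real Gemma vocabulary):

```python
# Sketch based on the rc2 __init__ shown above; illustrative values only.
from transformers import GemmaTokenizer

tok = GemmaTokenizer(
    vocab={"<pad>": 0, "<eos>": 1, "<bos>": 2, "<unk>": 3, "<mask>": 4},
    merges=[],  # rc2 no longer derives merges from the vocabulary
)
```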
transformers/models/gemma2/modeling_gemma2.py

```diff
@@ -29,7 +29,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_func_from_hub
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -42,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_gemma2 import Gemma2Config
 
 
@@ -99,7 +99,7 @@ class Gemma2RotaryEmbedding(nn.Module):
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
@@ -138,7 +138,7 @@ class Gemma2RotaryEmbedding(nn.Module):
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
@@ -229,6 +229,7 @@ def eager_attention_forward(
    return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Gemma2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -255,7 +256,6 @@ class Gemma2Attention(nn.Module):
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
-        self.rotary_fn = apply_rotary_pos_emb
        self.attn_logit_softcapping = self.config.attn_logit_softcapping
        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
```
@@ -34,6 +34,7 @@ from ...modeling_rope_utils import (
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, logging
+from ...utils.generic import maybe_autocast
 from ..gemma.modeling_gemma import (
     GemmaAttention,
     GemmaForCausalLM,
@@ -243,7 +244,7 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
@@ -252,7 +253,7 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -231,7 +231,6 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
         return BatchFeature(
             data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors
         )
@@ -31,16 +31,15 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_func_from_hub
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from ..auto import AutoModel
 from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig

@@ -101,6 +100,7 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding):

     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.scalar_embed_scale = embed_scale
         self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

     def forward(self, input_ids: torch.Tensor):
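`Gemma3TextScaledWordEmbedding` now keeps the scale twice: as a plain float (`scalar_embed_scale`), which the new `_init_weights` branch below uses to re-materialize the buffer, and as a non-persistent tensor buffer that follows device and dtype moves. A runnable sketch of the same shape; the forward body is an assumption based on the usual scaled-embedding pattern, not copied from the diff:

```python
import torch
from torch import nn

class ScaledWordEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, padding_idx, embed_scale=1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.scalar_embed_scale = embed_scale  # plain float, kept to re-init the buffer
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids):
        # Scale lookups by the buffered factor, cast to the embedding dtype.
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)

emb = ScaledWordEmbedding(10, 4, padding_idx=0, embed_scale=2.0)
print(emb(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 4])
```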
@@ -166,7 +166,7 @@ class Gemma3RotaryEmbedding(nn.Module):
             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
             curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
             self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-
+            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
             setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)

     @staticmethod
@@ -215,7 +215,7 @@ class Gemma3RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * attention_scaling
@@ -306,6 +306,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Gemma3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -332,7 +333,6 @@ class Gemma3Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.attn_logit_softcapping = self.config.attn_logit_softcapping
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
         self.is_sliding = self.layer_type == "sliding_attention"
@@ -347,7 +347,7 @@ class Gemma3Attention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -409,23 +409,19 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states

         hidden_states = self.input_layernorm(hidden_states)

-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
@@ -438,12 +434,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states

-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return hidden_states


 @auto_docstring
@@ -478,6 +469,16 @@ class Gemma3PreTrainedModel(PreTrainedModel):
         # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
         elif "RMSNorm" in module.__class__.__name__:
             init.zeros_(module.weight)
+        elif isinstance(module, Gemma3TextScaledWordEmbedding):
+            init.constant_(module.embed_scale, module.scalar_embed_scale)
+        elif isinstance(module, Gemma3RotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)


 def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
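These new `_init_weights` branches re-create the values of non-persistent buffers, which are absent from checkpoints, whenever weights are (re)initialized. The `init.copy_`/`init.constant_` helpers behave like their in-place `torch` counterparts under `no_grad`; a minimal sketch of that idea, not the transformers `initialization` module itself:

```python
import torch

@torch.no_grad()
def copy_(tensor: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
    return tensor.copy_(source)   # overwrite the buffer in place

@torch.no_grad()
def constant_(tensor: torch.Tensor, value: float) -> torch.Tensor:
    return tensor.fill_(value)    # e.g. restore embed_scale from its float twin

buf = torch.empty(4)
copy_(buf, torch.arange(4, dtype=torch.float32))
print(buf)  # tensor([0., 1., 2., 3.])
```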
@@ -527,30 +528,16 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if use_cache and past_key_values is None
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
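With `output_attentions`/`output_hidden_states` dropped from the signature, those flags now travel through `**kwargs: Unpack[TransformersKwargs]` and are consumed by the shared `check_model_inputs` machinery rather than by hand-rolled defaults in every forward. A small sketch of that calling convention with a toy function, not the real decorator:

```python
from typing import Optional

def forward(input_ids=None, use_cache: Optional[bool] = None, **kwargs):
    # The flags no longer appear explicitly; they ride along in **kwargs,
    # where shared machinery (check_model_inputs in transformers) picks them up.
    return kwargs

print(forward(input_ids=[1, 2, 3], output_hidden_states=True))
# {'output_hidden_states': True}
```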
@@ -591,41 +578,22 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_embeddings=position_embeddings[decoder_layer.attention_type],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)

-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )


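The decoder loop shrinks because each layer now returns a single tensor instead of a `(hidden_states, attn_weights)` tuple, and hidden-state/attention collection moved out of the loop into shared machinery. A toy version of the new contract:

```python
import torch
from torch import nn

class ToyLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, **kwargs):
        return hidden_states + self.proj(hidden_states)  # single tensor out

layers = nn.ModuleList(ToyLayer(8) for _ in range(3))
hidden_states = torch.randn(1, 5, 8)
for layer in layers:                      # mirrors the new Gemma3TextModel loop:
    hidden_states = layer(hidden_states)  # no tuple indexing, no per-layer bookkeeping
```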
@@ -797,6 +765,7 @@ def create_causal_mask_mapping(
     token_type_ids: Optional[torch.Tensor] = None,
     pixel_values: Optional[torch.FloatTensor] = None,
     is_training: bool = False,
+    is_first_iteration: Optional[bool] = None,
     **kwargs,
 ) -> dict:
     """
@@ -819,8 +788,12 @@ def create_causal_mask_mapping(
     # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
     # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
     # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
-    may_have_image_input = past_key_values is None or not past_key_values.is_initialized or pixel_values is not None
-    if token_type_ids is not None and may_have_image_input:
+    is_first_iteration = (
+        is_first_iteration
+        if is_first_iteration is not None
+        else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
+    )
+    if token_type_ids is not None and is_first_iteration:
         # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
         # undo the causal masking)

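When the caller does not pass `is_first_iteration`, the default reduces to the same cache/pixel predicate the old `may_have_image_input` computed. The predicate in isolation, with a stubbed cache object (illustrative only):

```python
from types import SimpleNamespace

def default_is_first_iteration(past_key_values, pixel_values):
    return past_key_values is None or not past_key_values.is_initialized or pixel_values is not None

print(default_is_first_iteration(None, None))                                      # True: no cache yet
print(default_is_first_iteration(SimpleNamespace(is_initialized=True), None))      # False: decode step
print(default_is_first_iteration(SimpleNamespace(is_initialized=True), object()))  # True: fresh image input
```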
@@ -918,10 +891,7 @@ class Gemma3Model(Gemma3PreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Gemma3ModelOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -953,12 +923,6 @@ class Gemma3Model(Gemma3PreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # Replace image id with PAD if the image token is OOV, to avoid index-errors
         if input_ids is not None and self.config.image_token_id >= self.vocab_size:
             special_image_mask = input_ids == self.config.image_token_id
@@ -1005,8 +969,6 @@ class Gemma3Model(Gemma3PreTrainedModel):
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             return_dict=True,
             cache_position=cache_position,
             **lm_kwargs,
@@ -1014,7 +976,7 @@ class Gemma3Model(Gemma3PreTrainedModel):

         return Gemma3ModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
-            past_key_values=outputs.past_key_values if use_cache else None,
+            past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
             image_hidden_states=image_features if pixel_values is not None else None,
@@ -1053,6 +1015,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
     def get_image_features(self, pixel_values):
         return self.model.get_image_features(pixel_values)

+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
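Adding `@can_return_tuple` lets the forward always build a `ModelOutput` while still honoring callers that pass `return_dict=False`, which is why the manual `if not return_dict:` tuple path disappears further down. A rough sketch of the decorator's idea; the real implementation also consults the model config, which is omitted here:

```python
import functools

def can_return_tuple(fn):
    @functools.wraps(fn)
    def wrapper(self, *args, return_dict: bool = True, **kwargs):
        output = fn(self, *args, **kwargs)          # forward always builds a ModelOutput
        return output if return_dict else output.to_tuple()
    return wrapper
```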
@@ -1066,11 +1029,8 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1116,13 +1076,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
         ```
         """
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         outputs = self.model(
             input_ids=input_ids,
             pixel_values=pixel_values,
@@ -1133,9 +1086,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             labels=labels,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             cache_position=cache_position,
             **lm_kwargs,
         )
@@ -1167,10 +1117,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             flat_labels = shift_labels.view(-1).to(shift_logits.device)
             loss = loss_fct(flat_logits, flat_labels)

-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
         return Gemma3CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
@@ -1193,6 +1139,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         use_cache=True,
         logits_to_keep=None,
         labels=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -1206,12 +1153,15 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             use_cache=use_cache,
             logits_to_keep=logits_to_keep,
             token_type_ids=token_type_ids,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-        # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
-        if cache_position[0] == 0:
+        # Pixel values are used only in the first iteration, if available.
+        # In subsequent iterations they are already merged with text and cached.
+        # NOTE: the first iteration doesn't have to be prefill; it can be the first
+        # iteration with a question and a cached system prompt (continuing from cache). NOTE: use_cache=False always needs pixel_values.
+        if is_first_iteration or not use_cache:
             model_inputs["pixel_values"] = pixel_values

         return model_inputs
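The gating itself is tiny: pixel values are forwarded only on the first iteration or when caching is off. Extracted into a standalone function for illustration:

```python
def gate_pixel_values(model_inputs: dict, pixel_values, is_first_iteration: bool, use_cache: bool) -> dict:
    if is_first_iteration or not use_cache:
        model_inputs["pixel_values"] = pixel_values  # later steps reuse the merged cache
    return model_inputs

print(gate_pixel_values({}, "PIXELS", is_first_iteration=True, use_cache=True))   # {'pixel_values': 'PIXELS'}
print(gate_pixel_values({}, "PIXELS", is_first_iteration=False, use_cache=True))  # {}
```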
@@ -1225,6 +1175,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         past_key_values: Optional[Cache],
         position_ids: Optional[torch.Tensor],
         token_type_ids: Optional[torch.Tensor] = None,
+        is_first_iteration: Optional[bool] = False,
         **kwargs,
     ) -> dict:
         # Uses the overwritten `create_masks_for_generate` with `token_type_ids` masking
@@ -1236,7 +1187,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             past_key_values,
             position_ids,
             token_type_ids,
-
+            is_first_iteration=is_first_iteration,
             **{k: v for k, v in kwargs.items() if k != "pixel_values"},
         )
