transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/granite_speech/modeling_granite_speech.py

@@ -293,6 +293,12 @@ class GraniteSpeechPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, GraniteSpeechEncoderProjector):
             init.normal_(module.query)
+        elif isinstance(module, GraniteSpeechCTCEncoder):
+            context_size = module.config.context_size
+            seq = torch.arange(context_size)
+            relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
+            attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + module.config.max_pos_emb
+            init.copy_(module.attention_dists, attention_dists)
 
 
 @auto_docstring(
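The new `GraniteSpeechCTCEncoder` branch precomputes a clamped relative-distance table used to index relative position embeddings. A minimal standalone sketch of the table it builds, with toy values (`context_size=4`, `max_pos_emb=4` are illustrative, not the config defaults):

import torch

# Toy reproduction of the initialization above; values are illustrative.
context_size, max_pos_emb = 4, 4
seq = torch.arange(context_size)
relpos_dist = seq.view(-1, 1) - seq.view(1, -1)   # relpos_dist[i, j] = i - j
attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + max_pos_emb
# relpos_dist:                 attention_dists (shifted so all indices >= 0):
# [[ 0, -1, -2, -3],           [[4, 3, 2, 1],
#  [ 1,  0, -1, -2],            [5, 4, 3, 2],
#  [ 2,  1,  0, -1],            [6, 5, 4, 3],
#  [ 3,  2,  1,  0]]            [7, 6, 5, 4]]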
@@ -322,6 +328,12 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera
 
         self.post_init()
 
+    def set_decoder(self, decoder):
+        self.language_model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
@@ -458,6 +470,7 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model
@@ -469,13 +482,14 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
         # If we're in cached decoding stage, input_features should be None because
         # input ids do not contain special audio token anymore Otherwise we need
         # input feature values to be passed to the model
-        if
+        if is_first_iteration or not kwargs.get("use_cache", True):
            model_inputs["input_features"] = input_features
         return model_inputs
 
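The effect of the new `is_first_iteration` flag: audio features are only forwarded on the prefill step, or when caching is disabled, since later decode steps read the audio context from the KV cache. A toy sketch of the gating, with hypothetical names rather than the transformers API:

def prepare_inputs(input_features, is_first_iteration, use_cache=True):
    model_inputs = {}
    # Prefill: the prompt still contains audio placeholder tokens, so the
    # features must be embedded. Cached decode: the KV cache already holds
    # the audio context, so the features are dropped.
    if is_first_iteration or not use_cache:
        model_inputs["input_features"] = input_features
    return model_inputs

assert "input_features" in prepare_inputs("feats", is_first_iteration=True)
assert "input_features" not in prepare_inputs("feats", is_first_iteration=False)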
transformers/models/granitemoe/modeling_granitemoe.py

@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring
-from ...utils.generic import can_return_tuple, check_model_inputs
+from ...utils.generic import can_return_tuple, check_model_inputs, maybe_autocast
 from .configuration_granitemoe import GraniteMoeConfig
 
 
@@ -80,7 +80,7 @@ class GraniteMoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
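The switch from a plain attribute to `register_buffer(..., persistent=False)` means `original_inv_freq` now follows the module across `.to(device)` and dtype casts while staying out of `state_dict()`. A minimal sketch of those semantics (hypothetical module, not the transformers class):

import torch
from torch import nn

class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

rope = Rope()
# Non-persistent buffers move with the module but are never serialized:
assert "original_inv_freq" not in rope.state_dict()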
@@ -119,7 +119,7 @@ class GraniteMoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
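`maybe_autocast(..., enabled=False)` disables any ambient mixed-precision context so the RoPE angles are computed in full float32. A rough equivalent with plain `torch.autocast` (assuming `maybe_autocast` is a tolerant wrapper around it, which this diff does not show):

import torch

inv_freq = torch.rand(1, 4, 1)
position_ids = torch.arange(16).float().view(1, 1, -1)
device_type = "cuda" if torch.cuda.is_available() else "cpu"
with torch.autocast(device_type=device_type, enabled=False):  # force float32
    # Under bf16/fp16 autocast this matmul would otherwise run in low
    # precision, degrading the cos/sin tables used by rotary attention.
    freqs = (inv_freq.float() @ position_ids.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos(), emb.sin()
assert cos.dtype == torch.float32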
@@ -338,6 +338,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class GraniteMoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -363,7 +364,6 @@ class GraniteMoeAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
 
     def forward(
         self,
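Taken together, these two hunks move the rotary function from a per-instance attribute set in `__init__` to a class-level decorator. A speculative sketch of the pattern (the real `use_kernelized_func` in `transformers.integrations` may differ; this only illustrates why the `self.rotary_fn` assignment becomes redundant):

def use_kernelized_func(func):
    # Hypothetical: attach the reference implementation at the class level,
    # so a hub-provided kernel can later replace it in one place.
    def wrap(cls):
        cls.rotary_fn = staticmethod(func)
        return cls
    return wrap

def apply_rotary_pos_emb(q, k, cos, sin):
    ...  # reference implementation elided

@use_kernelized_func(apply_rotary_pos_emb)
class Attention:
    pass

assert Attention.rotary_fn is apply_rotary_pos_emb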
@@ -456,8 +456,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": GraniteMoeDecoderLayer,
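The updated comment pins the fullgraph blocker to a concrete line: `Tensor.tolist()` forces a host sync that TorchDynamo cannot keep inside a single graph. A minimal repro (behavior typical of torch 2.x; the exact exception type may vary by version):

import torch

@torch.compile(fullgraph=True)
def route(expert_size):
    return expert_size.tolist()  # host sync -> graph break under fullgraph=True

try:
    route(torch.tensor([3, 1, 4]))
except Exception as exc:
    print(type(exc).__name__)  # a Dynamo "Unsupported"-style error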
@@ -714,8 +713,6 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
 
         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
             # Flatten the tokens
             loss = self.loss_function(
                 logits,
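The explicit call-site upcast disappears here; presumably the shared `loss_function` now handles precision itself, so the full logits tensor is no longer duplicated in float32 before the loss. A hedged sketch of upcasting inside the loss instead (not transformers' actual implementation):

import torch
import torch.nn.functional as F

def causal_lm_loss(logits, labels):
    logits = logits.float()  # upcast once, next to the cross-entropy call
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_labels = labels[:, 1:].reshape(-1)
    return F.cross_entropy(shift_logits, shift_labels)

loss = causal_lm_loss(torch.randn(2, 5, 10, dtype=torch.bfloat16),
                      torch.randint(0, 10, (2, 5)))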
transformers/models/granitemoe/modular_granitemoe.py

@@ -146,8 +146,7 @@ class GraniteMoePreTrainedModel(LlamaPreTrainedModel, PreTrainedModel):
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
 
     @torch.no_grad()
     def _init_weights(self, module):
@@ -295,8 +294,6 @@ class GraniteMoeForCausalLM(MixtralForCausalLM):
|
|
|
295
294
|
|
|
296
295
|
loss = None
|
|
297
296
|
if labels is not None:
|
|
298
|
-
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
|
299
|
-
logits = logits.float()
|
|
300
297
|
# Flatten the tokens
|
|
301
298
|
loss = self.loss_function(
|
|
302
299
|
logits,
|
|
transformers/models/granitemoehybrid/configuration_granitemoehybrid.py

@@ -92,6 +92,8 @@ class GraniteMoeHybridConfig(PreTrainedConfig):
             allow the model to output the auxiliary loss.
         router_aux_loss_coef (`float`, *optional*, defaults to 0.001): router auxiliary loss coefficient
         shared_intermediate_size (`int`, *optional*, defaults to 1024): intermediate size for shared experts.
+        position_embedding_type (`str`, *optional*):
+            Positional embedding type to be used; defaults to None. Allowed options: `[None, "rope"]`
         layer_types (`List`, *optional*): list of strings to be used as layer types.
             Allowed choices: "mamba", "attention".
         mamba_n_heads (`int`, *optional*, defaults to 128):

@@ -159,6 +161,7 @@ class GraniteMoeHybridConfig(PreTrainedConfig):
         output_router_logits: Optional[bool] = False,
         router_aux_loss_coef: Optional[float] = 0.001,
         shared_intermediate_size: Optional[int] = 1024,
+        position_embedding_type: Optional[str] = None,
         layer_types: Optional[list[str]] = None,
         mamba_n_heads: Optional[int] = 128,
         mamba_n_groups: Optional[int] = 1,

@@ -198,6 +201,7 @@ class GraniteMoeHybridConfig(PreTrainedConfig):
         self.output_router_logits = output_router_logits
         self.router_aux_loss_coef = router_aux_loss_coef
         self.shared_intermediate_size = shared_intermediate_size
+        self.position_embedding_type = position_embedding_type
         self.rope_parameters = rope_parameters

         mamba_intermediate = mamba_expand * hidden_size
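Downstream, `GraniteMoeHybridModel` only builds a rotary embedding when this field is `"rope"` (see the modeling hunks below), so attention layers default to no positional embedding. Opting in looks like:

    from transformers import GraniteMoeHybridConfig

    # Default is None (no rotary embedding in the attention layers);
    # "rope" switches the hybrid model's attention layers to RoPE.
    config = GraniteMoeHybridConfig(position_embedding_type="rope")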
transformers/models/granitemoehybrid/modeling_granitemoehybrid.py

@@ -31,7 +31,12 @@ from transformers.activations import ACT2FN
 from ... import initialization as init
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import (
+    lazy_load_kernel,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast

@@ -39,23 +44,10 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_granitemoehybrid import GraniteMoeHybridConfig


-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    selective_state_update = None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-
 logger = logging.get_logger(__name__)

@@ -132,6 +124,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class GraniteMoeHybridAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -157,7 +150,6 @@ class GraniteMoeHybridAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,

@@ -165,6 +157,7 @@ class GraniteMoeHybridAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # None or rope embeddings
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]

@@ -174,6 +167,10 @@ class GraniteMoeHybridAttention(nn.Module):
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
         if past_key_values is not None:
             cache_kwargs = {"cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -371,9 +368,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
     return hidden_states


-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
 class GraniteMoeHybridMambaLayer(nn.Module):
     """

@@ -445,6 +439,20 @@ class GraniteMoeHybridMambaLayer(nn.Module):

         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
+        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
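The net effect of these hunks: the optional `mamba-ssm` and `causal-conv1d` kernels are no longer probed at import time but resolved when the first Mamba layer is constructed, and `getattr(..., None)` preserves the old `None` sentinels when a kernel is missing. `lazy_load_kernel`'s real implementation lives in `transformers.integrations` (it can also fetch hub kernels) and is not shown here; a hypothetical fallback-only sketch of the idea:

    import importlib

    def lazy_load_kernel_sketch(name: str):
        # Assumed behavior: resolve a kernel distribution by name,
        # returning None when it isn't installed locally.
        try:
            return importlib.import_module(name.replace("-", "_"))
        except ImportError:
            return None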
@@ -915,7 +923,7 @@ class GraniteMoeHybridRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(

@@ -954,7 +962,7 @@ class GraniteMoeHybridRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -1231,8 +1239,7 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": GraniteMoeHybridDecoderLayer,

@@ -1265,7 +1272,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
             [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config)
+        self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if config.position_embedding_type == "rope" else None
         self.gradient_checkpointing = False
         self.embedding_multiplier = config.embedding_multiplier

@@ -1313,7 +1320,9 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):

         # embed positions
         hidden_states = inputs_embeds
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        position_embeddings = None
+        if self.rotary_emb is not None:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)

         for decoder_layer in self.layers:
             # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)

@@ -1510,8 +1519,6 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
             # Flatten the tokens
             loss = self.loss_function(
                 logits,

@@ -1549,6 +1556,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMixin):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`

@@ -1581,7 +1589,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMixin):
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
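Previously this branch inferred the first generation step from cache state (the removed line is truncated in the source view; the standard rc0 pattern `cache_position[0] == 0` is assumed above), which presumably misfires once prefix caching or resumed generation makes step one start at a nonzero position. An explicit `is_first_iteration` flag passed down by the generation loop makes the dispatch unambiguous; a minimal self-contained mirror of the branch:

    def pick_inputs(input_ids, inputs_embeds, is_first_iteration):
        # Embeddings only feed the prefill step; later steps fall back
        # to the freshly generated token ids.
        if inputs_embeds is not None and is_first_iteration:
            return {"inputs_embeds": inputs_embeds}
        return {"input_ids": input_ids}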
transformers/models/granitemoehybrid/modular_granitemoehybrid.py

@@ -39,6 +39,7 @@ from ..granitemoeshared.modeling_granitemoeshared import (
     GraniteMoeSharedModel,
     GraniteMoeSharedMoE,
     GraniteMoeSharedPreTrainedModel,
+    apply_rotary_pos_emb,
     eager_attention_forward,
 )
 from .configuration_granitemoehybrid import GraniteMoeHybridConfig

@@ -57,6 +58,7 @@ class GraniteMoeHybridAttention(GraniteMoeSharedAttention):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # None or rope embeddings
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]

@@ -66,6 +68,10 @@ class GraniteMoeHybridAttention(GraniteMoeSharedAttention):
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
         if past_key_values is not None:
             cache_kwargs = {"cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -203,6 +209,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
             [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.embedding_multiplier = config.embedding_multiplier
+        self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if config.position_embedding_type == "rope" else None

     @auto_docstring
     @check_model_inputs

@@ -245,7 +252,9 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):

         # embed positions
         hidden_states = inputs_embeds
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        position_embeddings = None
+        if self.rotary_emb is not None:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)

         for decoder_layer in self.layers:
             # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)

@@ -300,6 +309,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`

@@ -332,7 +342,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
transformers/models/granitemoeshared/modeling_granitemoeshared.py

@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast

@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring
-from ...utils.generic import can_return_tuple, check_model_inputs
+from ...utils.generic import can_return_tuple, check_model_inputs, maybe_autocast
 from .configuration_granitemoeshared import GraniteMoeSharedConfig

@@ -328,6 +328,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class GraniteMoeSharedAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -353,7 +354,6 @@ class GraniteMoeSharedAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,

@@ -462,8 +462,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": GraniteMoeSharedDecoderLayer,

@@ -494,7 +493,7 @@ class GraniteMoeSharedRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(

@@ -533,7 +532,7 @@ class GraniteMoeSharedRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -785,8 +784,6 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
             # Flatten the tokens
             loss = self.loss_function(
                 logits,
transformers/models/grounding_dino/configuration_grounding_dino.py

@@ -34,7 +34,7 @@ class GroundingDinoConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.

     Args:
-        backbone_config (`PreTrainedConfig` or `dict`, *optional*, defaults to `SwinConfig()`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `SwinConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this

@@ -285,9 +285,8 @@ class GroundingDinoConfig(PreTrainedConfig):
         self.positional_embedding_temperature = positional_embedding_temperature
         self.init_std = init_std
         self.layer_norm_eps = layer_norm_eps
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
-        self.tie_encoder_decoder = True


 __all__ = ["GroundingDinoConfig"]
transformers/models/grounding_dino/modeling_grounding_dino.py

@@ -1415,7 +1415,7 @@ class GroundingDinoPreTrainedModel(PreTrainedModel):
         elif isinstance(module, GroundingDinoFusionLayer):
             init.constant_(module.vision_param, 1e-4)
             init.constant_(module.text_param, 1e-4)
-        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
             init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
                 init.zeros_(module.bias)

@@ -1510,7 +1510,8 @@ class GroundingDinoEncoder(GroundingDinoPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
-    ):
+        **kwargs,
+    ) -> Union[tuple, GroundingDinoEncoderOutput]:
         r"""
         Args:
             vision_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):

@@ -1664,7 +1665,8 @@ class GroundingDinoDecoder(GroundingDinoPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
-    ):
+        **kwargs,
+    ) -> Union[tuple, GroundingDinoDecoderOutput]:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):

@@ -2056,7 +2058,8 @@ class GroundingDinoModel(GroundingDinoPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
-    ):
+        **kwargs,
+    ) -> Union[tuple, GroundingDinoModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -2460,6 +2463,7 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[list[dict[str, Union[torch.LongTensor, torch.FloatTensor]]]] = None,
+        **kwargs,
     ):
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`):
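All four Grounding DINO forward signatures gain a trailing `**kwargs` (and, on the sub-modules, an explicit return annotation). Presumably this keeps the signatures tolerant of new pipeline-level arguments threaded through by shared wrappers, so an extra keyword no longer raises `TypeError`; a purely illustrative contrast:

    # A strict signature rejects unknown keywords; a tolerant one absorbs them.
    def strict_forward(pixel_values, return_dict=None):
        return pixel_values

    def tolerant_forward(pixel_values, return_dict=None, **kwargs):
        return pixel_values  # unknown kwargs are accepted and ignored

    tolerant_forward(0, output_attentions=False)   # fine
    # strict_forward(0, output_attentions=False)   # -> TypeError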
transformers/models/groupvit/modeling_groupvit.py

@@ -758,14 +758,19 @@ class GroupViTPreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=init_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
-        elif isinstance(module, nn.LayerNorm):
+        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)

         factor = self.config.initializer_factor
         if isinstance(module, GroupViTTextEmbeddings):
             init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
             init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, GroupViTAttention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
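The three buffer resets added above mirror what PyTorch itself does in `_BatchNorm.reset_running_stats()`, which makes for a quick cross-check:

    import torch.nn as nn

    bn = nn.BatchNorm1d(8)
    bn.reset_running_stats()
    assert bn.running_mean.abs().sum() == 0    # zeroed mean
    assert (bn.running_var == 1).all()         # variance reset to one
    assert bn.num_batches_tracked.item() == 0  # counter cleared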
@@ -1045,6 +1050,7 @@ class GroupViTTextModel(GroupViTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         Examples:

@@ -1145,6 +1151,7 @@ class GroupViTVisionModel(GroupViTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         Examples:

@@ -1297,6 +1304,7 @@ class GroupViTModel(GroupViTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_segmentation: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, GroupViTModelOutput]:
         r"""
         return_loss (`bool`, *optional*):
transformers/models/helium/modeling_helium.py

@@ -29,6 +29,7 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import (
     GenericForSequenceClassification,

@@ -40,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_helium import HeliumConfig

@@ -78,7 +79,7 @@ class HeliumRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(

@@ -117,7 +118,7 @@ class HeliumRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -220,6 +221,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class HeliumAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -243,7 +245,6 @@ class HeliumAttention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,