transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -31,7 +31,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_bidirectional_mask, create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (

@@ -45,7 +45,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_evolla import EvollaConfig, SaProtConfig
 
 

@@ -185,6 +185,7 @@ class EvollaSaProtRotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int):
         super().__init__()
+        self.dim = dim
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
         self.register_buffer("inv_freq", inv_freq)

@@ -518,12 +519,19 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel):
         ],
     }
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, EvollaSaProtRotaryEmbedding):
+            inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
 
 class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
     def __init__(self, config: SaProtConfig):
         super().__init__(config)
         self.embeddings = EvollaSaProtEmbeddings(config)
         self.encoder = EvollaSaProtEncoder(config)
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings

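The new `_init_weights` hook above re-derives `inv_freq` from `module.dim` instead of trusting whatever tensor happens to be in the buffer, which is why `self.dim` is now stored in `__init__`. A minimal standalone sketch of the pattern; the `copy_` helper from the new `transformers/initialization.py` module is assumed here to be a no-grad in-place copy:

```python
import torch
from torch import nn


def copy_(dst: torch.Tensor, src: torch.Tensor) -> None:
    # Assumed behavior of transformers.initialization.copy_:
    # overwrite an existing buffer in place without tracking gradients.
    with torch.no_grad():
        dst.copy_(src.to(device=dst.device, dtype=dst.dtype))


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim  # kept so _init_weights can rebuild the buffer later
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)


def init_weights(module: nn.Module) -> None:
    # Mirrors the Evolla hook: recompute inv_freq from module.dim so that a
    # full weight re-initialization restores the exact constructor value.
    if isinstance(module, RotaryEmbedding):
        inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
        copy_(module.inv_freq, inv_freq)
```

The added `self.post_init()` call in `EvollaSaProtProteinEncoder.__init__` is what actually triggers this hook once the submodules exist.
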
@@ -980,7 +988,7 @@ class EvollaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(

@@ -1019,7 +1027,7 @@ class EvollaRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

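`maybe_autocast` (imported from `...utils.generic` in the hunks above) replaces the bare `torch.autocast(...)` context manager; its implementation is not part of this diff. A plausible sketch, assuming it merely degrades to a null context on device types where autocast is unavailable:

```python
import contextlib

import torch


def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
    # Hypothetical reimplementation for illustration: behave like torch.autocast,
    # but fall back to a no-op context instead of raising on device types
    # that have no autocast support.
    try:
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:
        return contextlib.nullcontext()
```

In the RoPE forward passes it is always entered with `enabled=False`, so its only job is to switch off any ambient autocast and keep the inverse-frequency matmul in float32.
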
@@ -1091,6 +1099,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class EvollaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 

@@ -1116,7 +1125,6 @@ class EvollaAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
 
     def forward(
         self,

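Taken together, the two hunks above move the rotary function from a per-instance attribute (`self.rotary_fn = apply_rotary_pos_emb`, now deleted) to the class-level `@use_kernelized_func(apply_rotary_pos_emb)` decorator. A hedged sketch of what such a decorator could look like; the real one in `transformers.integrations` presumably also consults the hub-kernel registry:

```python
from typing import Callable


def use_kernelized_func(func: Callable):
    # Illustrative stand-in only: attach the reference implementation to the
    # class as `rotary_fn`, giving one central place where a faster hub kernel
    # could later be swapped in for every instance at once.
    def decorator(cls):
        cls.rotary_fn = staticmethod(func)
        return cls

    return decorator
```
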
@@ -91,6 +91,7 @@ class EvollaSaProtRotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int):
         super().__init__()
+        self.dim = dim
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
         self.register_buffer("inv_freq", inv_freq)

@@ -203,12 +204,19 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel):
         ],
     }
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, EvollaSaProtRotaryEmbedding):
+            inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
 
 class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
     def __init__(self, config: SaProtConfig):
         super().__init__(config)
         self.embeddings = EvollaSaProtEmbeddings(config)
         self.encoder = EvollaSaProtEncoder(config)
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings

@@ -44,6 +44,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import maybe_autocast
 from .configuration_exaone4 import Exaone4Config
 
 

@@ -85,7 +86,7 @@ class Exaone4RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(

@@ -124,7 +125,7 @@ class Exaone4RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
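The `maybe_autocast` helper is imported from `...utils.generic`; its implementation is not shown in this diff. A plausible sketch of what such a guard does, under the assumption that it simply skips autocast management where it is unavailable rather than letting `torch.autocast` raise:

```python
# Hedged sketch only -- the real transformers helper may differ.
from contextlib import nullcontext

import torch

def maybe_autocast(device_type=None, enabled=True, **kwargs):
    # torch.autocast raises for device types with no autocast support;
    # fall back to a no-op context manager in that case.
    if device_type is None:
        return nullcontext()
    try:
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:
        return nullcontext()
```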
@@ -239,7 +240,6 @@ class Exaone4Attention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]

@@ -260,7 +260,6 @@ class Exaone4Attention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -48,6 +48,7 @@ from ...utils import (
     auto_docstring,
     logging,
 )
+from ...utils.generic import maybe_autocast
 from .configuration_falcon import FalconConfig


@@ -121,7 +122,7 @@ class FalconRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(

@@ -160,7 +161,7 @@ class FalconRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -520,8 +521,8 @@ class FalconFlashAttention2(FalconAttention):
                 else torch.get_autocast_gpu_dtype()
             )
         # Handle the case where the model is quantized
-        elif hasattr(self.config, "_pre_quantization_dtype"):
-            target_dtype = self.config._pre_quantization_dtype
+        elif hasattr(self.config, "quantization_config"):
+            target_dtype = self.config.dtype
         else:
             target_dtype = self.query_key_value.weight.dtype

@@ -739,6 +740,7 @@ class FalconModel(FalconPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):

@@ -1119,6 +1121,7 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):

@@ -1243,6 +1246,7 @@ class FalconForTokenClassification(FalconPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):

@@ -1320,6 +1324,7 @@ class FalconForQuestionAnswering(FalconPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -36,7 +36,7 @@ from transformers.activations import ACT2FN
 from ... import initialization as init
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer

@@ -45,6 +45,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ...utils.generic import maybe_autocast
 from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_falcon_h1 import FalconH1Config


@@ -240,7 +241,7 @@ class FalconH1RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(

@@ -279,7 +280,7 @@ class FalconH1RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -361,6 +362,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class FalconH1Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -386,7 +388,6 @@ class FalconH1Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.key_multiplier = config.key_multiplier

     def forward(
@@ -1186,26 +1187,6 @@ class FalconH1DecoderLayer(GradientCheckpointingLayer):
         return outputs


-@auto_docstring
-class FalconH1PreTrainedModel(PreTrainedModel):
-    config: FalconH1Config
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["FalconH1DecoderLayer"]
-    _skip_keys_device_placement = "past_key_values"
-    _supports_flash_attn = True
-    _supports_sdpa = True
-    _is_stateful = True
-
-    @torch.no_grad()
-    def _init_weights(self, module):
-        super()._init_weights(module)
-        if isinstance(module, FalconH1Mixer):
-            init.ones_(module.dt_bias)
-            init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
-            init.ones_(module.D)
-
-
 def compute_mup_vector(config):
     """
     Computes the MuP vector based on model configuration.

@@ -1243,6 +1224,30 @@ def compute_mup_vector(config):
     return mup_vector


+@auto_docstring
+class FalconH1PreTrainedModel(PreTrainedModel):
+    config: FalconH1Config
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["FalconH1DecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _is_stateful = True
+
+    @torch.no_grad()
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, FalconH1Mixer):
+            init.ones_(module.dt_bias)
+            init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
+            init.ones_(module.D)
+        elif isinstance(module, FalconH1Model):
+            mup_vector = compute_mup_vector(module.config)
+            for layer in module.layers:
+                init.copy_(layer.mamba.mup_vector, mup_vector)
+
+
 @auto_docstring
 # Adapted from transformers.models.jamba.modeling_jamba.JambaModel
 class FalconH1Model(FalconH1PreTrainedModel):
@@ -1268,7 +1273,7 @@ class FalconH1Model(FalconH1PreTrainedModel):
         # Compute the MuP vector once and register it for all layers
         mup_vector = compute_mup_vector(config)
         for layer in self.layers:
-            layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False)
+            layer.mamba.register_buffer("mup_vector", mup_vector.clone(), persistent=False)

         # Initialize weights and apply final processing
         self.post_init()
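Registering the same tensor on every layer makes all `mup_vector` buffers share one storage, so the in-place `init.copy_` that `_init_weights` now performs per layer would write through every alias at once. The `clone()` gives each layer independent storage. A small demonstration of the hazard, using plain torch:

```python
import torch
from torch import nn

shared = torch.ones(4)
layers = [nn.Module(), nn.Module()]
for m in layers:
    m.register_buffer("mup_vector", shared, persistent=False)  # same storage everywhere

layers[0].mup_vector.mul_(2.0)
print(layers[1].mup_vector)  # tensor([2., 2., 2., 2.]) -- mutated through the alias

# With .clone() per registration, each layer owns its buffer and per-layer
# init.copy_-style writes stay local to that layer.
```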
@@ -1590,6 +1595,7 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`

@@ -1627,7 +1633,7 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin):
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
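Instead of inferring the first decoding step from the cache state (the removed condition is reconstructed above as the usual `cache_position[0] == 0` check), the generation loop now passes an explicit `is_first_iteration` flag. A hedged sketch of the calling convention; the loop below is illustrative, not the actual `GenerationMixin` internals:

```python
def greedy_loop(model, input_ids, inputs_embeds, max_new_tokens):
    # Illustrative only -- shows how the new flag is threaded through.
    past_key_values = None
    for step in range(max_new_tokens):
        model_inputs = model.prepare_inputs_for_generation(
            input_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=True,
            is_first_iteration=(step == 0),  # inputs_embeds are consumed only on step 0
        )
        # ... run the forward pass on model_inputs and append the sampled token ...
```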
@@ -928,6 +928,10 @@ class FalconH1PreTrainedModel(PreTrainedModel):
             init.ones_(module.dt_bias)
             init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
             init.ones_(module.D)
+        elif isinstance(module, FalconH1Model):
+            mup_vector = compute_mup_vector(module.config)
+            for layer in module.layers:
+                init.copy_(layer.mamba.mup_vector, mup_vector)


 def compute_mup_vector(config):

@@ -992,7 +996,7 @@ class FalconH1Model(FalconH1PreTrainedModel):
         # Compute the MuP vector once and register it for all layers
         mup_vector = compute_mup_vector(config)
         for layer in self.layers:
-            layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False)
+            layer.mamba.register_buffer("mup_vector", mup_vector.clone(), persistent=False)

         # Initialize weights and apply final processing
         self.post_init()

@@ -1298,6 +1302,7 @@ class FalconH1ForCausalLM(LlamaForCausalLM):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`

@@ -1335,7 +1340,7 @@ class FalconH1ForCausalLM(LlamaForCausalLM):
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
@@ -31,15 +31,11 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
-from ...integrations
+from ...integrations import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
-from ...utils.import_utils import (
-    is_mamba_ssm_available,
-    is_mambapy_available,
-    is_torchdynamo_compiling,
-)
+from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling
 from .configuration_falcon_mamba import FalconMambaConfig


@@ -48,14 +44,6 @@ if is_mambapy_available():
 else:
     pscan = None

-if is_mamba_ssm_available():
-    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-
-    from ...kernels.falcon_mamba import mamba_inner_fn
-else:
-    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
-

 logger = logging.get_logger(__name__)

@@ -231,7 +219,27 @@ class FalconMambaMixer(nn.Module):
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
         self.use_bias = config.use_bias

+        global causal_conv1d, causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
+        global falcon_mamba_ssm, selective_state_update, selective_scan_fn, falcon_mamba_inner_fn
+        falcon_mamba_ssm = lazy_load_kernel("falcon_mamba-ssm")
+        selective_state_update, selective_scan_fn, falcon_mamba_inner_fn = (
+            (
+                falcon_mamba_ssm.selective_state_update,
+                falcon_mamba_ssm.selective_scan_fn,
+                falcon_mamba_ssm.falcon_mamba_inner_fn,
+            )
+            if falcon_mamba_ssm is not None
+            else (None, None, None)
+        )
+
         self.warn_slow_implementation()
+
         # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
         self.register_buffer(
             "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
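Kernel handles are now resolved once, when a mixer is constructed, and published as module-level globals that the forward paths and `warn_slow_implementation` read; the `global` statement is what rebinds the module-level names from inside `__init__`. A stripped-down version of the pattern, with a stand-in loader for illustration:

```python
# Minimal sketch of the "resolve once, publish as module globals" pattern.
_kernel = None

def _load_kernel():
    # Stand-in for transformers' lazy_load_kernel: returns None when the
    # optional kernel package is not installed.
    return None

class Mixer:
    def __init__(self):
        global _kernel  # rebind the module-level name, not a local variable
        if _kernel is None:
            _kernel = _load_kernel()

    def forward(self, x):
        # Fast path only if the kernel resolved at construction time.
        return _kernel(x) if _kernel is not None else x
```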
@@ -242,14 +250,8 @@ class FalconMambaMixer(nn.Module):
         self.rms_eps = config.mixer_rms_eps

     def warn_slow_implementation(self):
-        causal_conv1d = lazy_load_kernel("causal-conv1d")
-        causal_conv1d_update, causal_conv1d_fn = (
-            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
-            if causal_conv1d is not None
-            else (None, None)
-        )
         is_fast_path_available = all(
-            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, falcon_mamba_inner_fn)
         )
         if not is_fast_path_available:
             if self.use_falcon_mambapy:
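`all()` over the tuple of kernel symbols treats any unresolved `None` as falsy, so a single missing kernel disables the entire fast path:

```python
selective_scan_fn, causal_conv1d_fn = object(), None  # one kernel missing
is_fast_path_available = all((selective_scan_fn, causal_conv1d_fn))
print(is_fast_path_available)  # False -> the module falls back to the slow path
```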
@@ -279,9 +281,8 @@ class FalconMambaMixer(nn.Module):
     ):
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states).transpose(1, 2)
-
         if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
-            contextualized_states = mamba_inner_fn(
+            contextualized_states = falcon_mamba_inner_fn(
                 projected_states,
                 self.conv1d.weight,
                 self.conv1d.bias if self.use_conv_bias else None,

@@ -302,12 +303,6 @@ class FalconMambaMixer(nn.Module):
             )

         else:
-            causal_conv1d = lazy_load_kernel("causal-conv1d")
-            causal_conv1d_update, causal_conv1d_fn = (
-                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
-                if causal_conv1d is not None
-                else (None, None)
-            )
             hidden_states, gate = projected_states.chunk(2, dim=1)

             if attention_mask is not None:

@@ -350,7 +345,7 @@ class FalconMambaMixer(nn.Module):

         # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
         # at the price of a small overhead.
-        if hasattr(self.config, "_pre_quantization_dtype"):
+        if hasattr(self.config, "quantization_config"):
             discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
         else:
             discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
@@ -502,14 +497,8 @@ class FalconMambaMixer(nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d = lazy_load_kernel("causal-conv1d")
-        causal_conv1d_update, causal_conv1d_fn = (
-            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
-            if causal_conv1d is not None
-            else (None, None)
-        )
         is_fast_path_available = all(
-            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, falcon_mamba_inner_fn)
         )
         if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not is_torchdynamo_compiling():
             return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
@@ -624,6 +613,9 @@ class FalconMambaPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight, std=std)
+        if isinstance(module, FalconMambaMixer):
+            init.ones_(module.b_c_rms)
+            init.ones_(module.dt_rms)


 @dataclass

@@ -703,6 +695,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, FalconMambaOutput]:
         r"""
         cache_params (`FalconMambaCache`, *optional*):

@@ -821,6 +814,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
         cache_params: Optional[FalconMambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        is_first_iteration: Optional[bool] = False,
         **kwargs,
     ):
         # Overwritten -- uses `cache_params` as opposed to `past_key_values`
@@ -19,9 +19,9 @@ from typing import Optional
 import torch
 from torch import nn

-from ...
+from ... import initialization as init
 from ...utils import auto_docstring, logging
-from ...utils.import_utils import is_mamba_ssm_available, is_mambapy_available, is_torchdynamo_compiling
+from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling
 from ..mamba.configuration_mamba import MambaConfig
 from ..mamba.modeling_mamba import (
     MambaBlock,

@@ -43,13 +43,13 @@ if is_mambapy_available():
 else:
     pscan = None

-if is_mamba_ssm_available():
-    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-
-    from ...kernels.falcon_mamba import mamba_inner_fn
-else:
-    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
+selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, falcon_mamba_inner_fn = (
+    None,
+    None,
+    None,
+    None,
+    None,
+)


 class FalconMambaConfig(MambaConfig):
@@ -251,14 +251,8 @@ def rms_forward(hidden_states, variance_epsilon=1e-6):

 class FalconMambaMixer(MambaMixer):
     def warn_slow_implementation(self):
-        causal_conv1d = lazy_load_kernel("causal-conv1d")
-        causal_conv1d_update, causal_conv1d_fn = (
-            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
-            if causal_conv1d is not None
-            else (None, None)
-        )
         is_fast_path_available = all(
-            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, falcon_mamba_inner_fn)
         )
         if not is_fast_path_available:
             if self.use_falcon_mambapy:

@@ -281,6 +275,7 @@ class FalconMambaMixer(MambaMixer):

     def __init__(self, config: FalconMambaConfig, layer_idx: int):
         super().__init__(config, layer_idx)
+
         # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
         self.register_buffer(
             "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False

@@ -299,9 +294,8 @@ class FalconMambaMixer(MambaMixer):
     ):
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states).transpose(1, 2)
-
         if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
-            contextualized_states = mamba_inner_fn(
+            contextualized_states = falcon_mamba_inner_fn(
                 projected_states,
                 self.conv1d.weight,
                 self.conv1d.bias if self.use_conv_bias else None,

@@ -322,12 +316,6 @@ class FalconMambaMixer(MambaMixer):
             )

         else:
-            causal_conv1d = lazy_load_kernel("causal-conv1d")
-            causal_conv1d_update, causal_conv1d_fn = (
-                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
-                if causal_conv1d is not None
-                else (None, None)
-            )
             hidden_states, gate = projected_states.chunk(2, dim=1)

             if attention_mask is not None:

@@ -370,7 +358,7 @@ class FalconMambaMixer(MambaMixer):

         # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
         # at the price of a small overhead.
-        if hasattr(self.config, "_pre_quantization_dtype"):
+        if hasattr(self.config, "quantization_config"):
             discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
         else:
             discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
@@ -521,14 +509,8 @@ class FalconMambaMixer(MambaMixer):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d = lazy_load_kernel("causal-conv1d")
-        causal_conv1d_update, causal_conv1d_fn = (
-            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
-            if causal_conv1d is not None
-            else (None, None)
-        )
         is_fast_path_available = all(
-            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, falcon_mamba_inner_fn)
        )
         if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not is_torchdynamo_compiling():
             return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)

@@ -548,7 +530,11 @@ class FalconMambaBlock(MambaBlock):

 @auto_docstring
 class FalconMambaPreTrainedModel(MambaPreTrainedModel):
-    pass
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, FalconMambaMixer):
+            init.ones_(module.b_c_rms)
+            init.ones_(module.dt_rms)


 class FalconMambaOutput(MambaOutput):