transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
--- a/transformers/models/mobilevit/image_processing_mobilevit_fast.py
+++ b/transformers/models/mobilevit/image_processing_mobilevit_fast.py
@@ -42,7 +42,7 @@ from .image_processing_mobilevit import MobileVitImageProcessorKwargs
 
 
 @auto_docstring
 class MobileViTImageProcessorFast(BaseImageProcessorFast):
-    resample = PILImageResampling.
+    resample = PILImageResampling.BICUBIC
     size = {"shortest_edge": 224}
     default_to_square = False
     crop_size = {"height": 256, "width": 256}
@@ -182,7 +182,6 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
 
         # Stack all processed images if return_tensors is specified
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
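Note: the removed torch.stack above suggests list-to-tensor conversion is now left entirely to BatchFeature, which presumably stacks list outputs itself when a tensor_type is passed. A minimal usage sketch (the checkpoint name is illustrative, not taken from this diff):

import torch
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("apple/mobilevit-small", use_fast=True)
image = torch.rand(3, 300, 300)  # a single RGB image as a tensor
batch = processor(images=[image, image], return_tensors="pt")
# pixel_values comes back as one stacked tensor, e.g. (2, 3, 256, 256)
# given the crop_size declared above
print(batch["pixel_values"].shape)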
--- a/transformers/models/mobilevit/modeling_mobilevit.py
+++ b/transformers/models/mobilevit/modeling_mobilevit.py
@@ -615,6 +615,10 @@ class MobileViTPreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
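Note: the four added lines give batch-norm running statistics an explicit, deterministic initialization (the same defaults PyTorch sets at construction), which matters e.g. when weights are materialized from a meta device and the buffers would otherwise hold uninitialized values. A standalone sketch of the effect, in plain torch rather than transformers code:

import torch
from torch import nn
from torch.nn import init

bn = nn.BatchNorm2d(8)
bn.running_mean.add_(123.0)  # pretend the buffer held garbage
if getattr(bn, "running_mean", None) is not None:  # same guard as the diff
    init.zeros_(bn.running_mean)
    init.ones_(bn.running_var)
    init.zeros_(bn.num_batches_tracked)
print(bn.running_mean.abs().max().item())  # 0.0 again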
@@ -659,6 +663,7 @@ class MobileViTModel(MobileViTPreTrainedModel):
         pixel_values: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -725,6 +730,7 @@ class MobileViTForImageClassification(MobileViTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -889,6 +895,7 @@ class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SemanticSegmenterOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
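Note: the **kwargs additions repeated across the MobileViT and MobileViTV2 forward signatures above (and in the ModernBERT hunks below) make these methods tolerant of extra keyword arguments instead of raising TypeError. A minimal sketch of the pattern, with illustrative names:

from typing import Optional
import torch

def forward(pixel_values: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, **kwargs):
    # keys this model does not use, e.g. output_attentions, are simply ignored
    return pixel_values

forward(torch.zeros(1, 3, 256, 256), output_attentions=False)  # accepted, no TypeError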
--- a/transformers/models/mobilevitv2/modeling_mobilevitv2.py
+++ b/transformers/models/mobilevitv2/modeling_mobilevitv2.py
@@ -582,6 +582,10 @@ class MobileViTV2PreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.GroupNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
@@ -623,6 +627,7 @@ class MobileViTV2Model(MobileViTV2PreTrainedModel):
         pixel_values: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -691,6 +696,7 @@ class MobileViTV2ForImageClassification(MobileViTV2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -858,6 +864,7 @@ class MobileViTV2ForSemanticSegmentation(MobileViTV2PreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SemanticSegmenterOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
--- a/transformers/models/modernbert/modeling_modernbert.py
+++ b/transformers/models/modernbert/modeling_modernbert.py
@@ -45,6 +45,7 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, is_flash_attn_2_available, logging
+from ...utils.generic import maybe_autocast
 from ...utils.import_utils import is_triton_available
 from .configuration_modernbert import ModernBertConfig
 
@@ -267,7 +268,7 @@ class ModernBertRotaryEmbedding(nn.Module):
             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
             curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
             self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-
+            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
             setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
 
     @staticmethod
@@ -316,7 +317,7 @@ class ModernBertRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * attention_scaling
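Note: maybe_autocast is imported from transformers' utils.generic in the hunk above; assuming it wraps torch.autocast while tolerating devices without autocast support, the enabled=False call forces the rotary-frequency matmul to run in float32 even inside an outer mixed-precision region. A plain-torch equivalent for illustration:

import torch

x = torch.randn(4, 4)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    with torch.autocast(device_type="cpu", enabled=False):  # what the helper presumably guards
        y = x.float() @ x.float().t()
print(y.dtype)  # torch.float32, unaffected by the outer autocast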
@@ -676,6 +677,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
             if module.bias is not None:
                 init.zeros_(module.bias)
+        elif isinstance(module, ModernBertRotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
+        elif isinstance(module, ModernBertUnpaddedRotaryEmbedding):
+            inv_freq = module._compute_inv_freq()
+            init.copy_(module.inv_freq, inv_freq)
 
     def _check_and_adjust_attn_implementation(
         self, attn_implementation: Optional[str], is_init_check: bool = False
@@ -852,6 +864,7 @@ class ModernBertModel(ModernBertPreTrainedModel):
|
|
|
852
864
|
output_attentions: Optional[bool] = None,
|
|
853
865
|
output_hidden_states: Optional[bool] = None,
|
|
854
866
|
return_dict: Optional[bool] = None,
|
|
867
|
+
**kwargs,
|
|
855
868
|
) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
|
|
856
869
|
r"""
|
|
857
870
|
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
@@ -1345,6 +1358,7 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
|
|
1345
1358
|
output_attentions: Optional[bool] = None,
|
|
1346
1359
|
output_hidden_states: Optional[bool] = None,
|
|
1347
1360
|
return_dict: Optional[bool] = None,
|
|
1361
|
+
**kwargs,
|
|
1348
1362
|
) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
|
|
1349
1363
|
r"""
|
|
1350
1364
|
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
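The `maybe_autocast` swap keeps the intent of the old `torch.autocast(..., enabled=False)` context: pin the inv_freq and position product to float32 so the cos/sin tables do not degrade under mixed precision. A runnable sketch under the assumption that `maybe_autocast` simply degrades to a no-op on devices without autocast support (the stand-in below is not the transformers implementation):

```python
import contextlib
import torch

def maybe_autocast_sketch(device_type: str, enabled: bool = False):
    # Assumed behavior: delegate to torch.autocast, fall back to a no-op
    # context when the device type has no autocast support.
    try:
        return torch.autocast(device_type=device_type, enabled=enabled)
    except RuntimeError:
        return contextlib.nullcontext()

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))
position_ids = torch.arange(6).float()[None, :]
with maybe_autocast_sketch("cpu", enabled=False):  # force float32, as in the diff
    freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos(), emb.sin()  # (1, 6, 8) tables, computed in fp32
```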
transformers/models/modernbert/modular_modernbert.py

@@ -35,7 +35,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_rope_utils import RopeParameters
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, is_flash_attn_2_available, logging
 from ...utils.import_utils import is_triton_available

@@ -871,6 +871,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
             if module.bias is not None:
                 init.zeros_(module.bias)
+        elif isinstance(module, ModernBertRotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
+        elif isinstance(module, ModernBertUnpaddedRotaryEmbedding):
+            inv_freq = module._compute_inv_freq()
+            init.copy_(module.inv_freq, inv_freq)
 
     def _check_and_adjust_attn_implementation(
         self, attn_implementation: Optional[str], is_init_check: bool = False

@@ -975,6 +986,7 @@ class ModernBertModel(ModernBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
         r"""
         sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -1468,6 +1480,7 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
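Both the generated and modular files gain the same `_init_weights` branch. The rotary buffers are registered with `persistent=False`, so they are absent from checkpoints and must be re-derived from config rather than left to generic (or meta-device) initialization. A self-contained sketch of that idea, with a made-up module and sizes:

```python
import torch
from torch import nn

class TinyRotary(nn.Module):
    def __init__(self, head_dim: int = 8, rope_theta: float = 10000.0):
        super().__init__()
        self.head_dim, self.rope_theta = head_dim, rope_theta
        inv_freq = self._compute_inv_freq()
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    def _compute_inv_freq(self) -> torch.Tensor:
        # Deterministic: fully determined by head_dim and rope_theta.
        exponents = torch.arange(0, self.head_dim, 2).float() / self.head_dim
        return 1.0 / (self.rope_theta ** exponents)

def init_weights_sketch(module: nn.Module) -> None:
    # Mirrors the diff's init.copy_ calls: recompute the table and copy it
    # into the existing buffers in place.
    if isinstance(module, TinyRotary):
        inv_freq = module._compute_inv_freq()
        with torch.no_grad():
            module.inv_freq.copy_(inv_freq)
            module.original_inv_freq.copy_(inv_freq)
```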
transformers/models/modernbert_decoder/modeling_modernbert_decoder.py

@@ -39,7 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_modernbert_decoder import ModernBertDecoderConfig
 
 

@@ -119,7 +119,7 @@ class ModernBertDecoderRotaryEmbedding(nn.Module):
         rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
         curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
         self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-
+        self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
         setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
 
     @staticmethod

@@ -168,7 +168,7 @@ class ModernBertDecoderRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * attention_scaling

@@ -342,7 +342,7 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.attn_norm(hidden_states)

@@ -443,6 +443,14 @@ class ModernBertDecoderPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
             if module.bias is not None:
                 init.zeros_(module.bias)
+        elif isinstance(module, ModernBertDecoderRotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
 
 
 @auto_docstring

@@ -477,7 +485,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
         if (input_ids is None) == (inputs_embeds is None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

@@ -489,7 +497,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         batch_size, seq_length = inputs_embeds.shape[:2]
 
         # Handle past_key_values and cache setup
-        if use_cache and past_key_values is None
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:

@@ -527,13 +535,12 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for decoder_layer in self.layers:
             hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_embeddings=position_embeddings[decoder_layer.attention_type],
                 past_key_values=past_key_values,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 position_ids=position_ids,
                 **kwargs,

@@ -583,7 +590,7 @@ class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationMixin):
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -686,7 +693,7 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
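The `**kwargs` → `**kwargs: Unpack[TransformersKwargs]` changes are purely a typing upgrade: `Unpack` over a `TypedDict` lets static checkers validate keyword-argument names and types while runtime behavior is unchanged. A sketch of the pattern; `MyLayerKwargs` is illustrative, not the real `TransformersKwargs` definition:

```python
from typing import TypedDict

from typing_extensions import Unpack


class MyLayerKwargs(TypedDict, total=False):
    output_attentions: bool
    output_hidden_states: bool


def layer_forward(hidden_states, **kwargs: Unpack[MyLayerKwargs]):
    # At runtime kwargs is an ordinary dict; the annotation only guides tooling.
    if kwargs.get("output_attentions"):
        pass
    return hidden_states
```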
transformers/models/modernbert_decoder/modular_modernbert_decoder.py

@@ -28,7 +28,7 @@ from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
-from ...modeling_rope_utils import RopeParameters
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging

@@ -394,7 +394,7 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.attn_norm(hidden_states)

@@ -482,6 +482,14 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
             init.ones_(module.weight)
             if module.bias is not None:
                 init.zeros_(module.bias)
+        elif isinstance(module, ModernBertDecoderRotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
 
     def _check_and_adjust_attn_implementation(self, attn_implementation, is_init_check):
         raise AttributeError("No need to inherit!")

@@ -525,7 +533,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
         if (input_ids is None) == (inputs_embeds is None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

@@ -537,7 +545,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         batch_size, seq_length = inputs_embeds.shape[:2]
 
         # Handle past_key_values and cache setup
-        if use_cache and past_key_values is None
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:

@@ -575,13 +583,12 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for decoder_layer in self.layers:
             hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_embeddings=position_embeddings[decoder_layer.attention_type],
                 past_key_values=past_key_values,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 position_ids=position_ids,
                 **kwargs,

@@ -631,7 +638,7 @@ class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationMixin):
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -734,7 +741,7 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
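Both decoder files share the corrected cache setup: a `DynamicCache` is built from the model config only when caching is requested and the caller supplied none. A control-flow sketch; `DynamicCache(config=...)` is taken from the diff itself, the helper around it is illustrative:

```python
from transformers import DynamicCache


def ensure_cache(config, use_cache: bool, past_key_values=None):
    # Lazily create the cache so callers can still pass their own
    # (e.g. a StaticCache) or disable caching entirely.
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache(config=config)
    return past_key_values
```

Note that `use_cache=use_cache` is no longer forwarded to each decoder layer: the cache object itself now carries all the state a layer needs.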
transformers/models/moonshine/modeling_moonshine.py

@@ -30,6 +30,7 @@ from transformers.utils.generic import OutputRecorder, check_model_inputs
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
 from ...modeling_flash_attention_utils import FlashAttentionKwargs

@@ -45,6 +46,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import maybe_autocast
 from .configuration_moonshine import MoonshineConfig
 
 

@@ -96,7 +98,7 @@ class MoonshineRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(

@@ -137,7 +139,7 @@ class MoonshineRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
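The move from the plain attribute `self.original_inv_freq = self.inv_freq` to `register_buffer(..., persistent=False)` matters because buffers follow `Module.to(...)` conversions and device moves, while bare tensor attributes are left behind. A small runnable sketch of the difference:

```python
import torch
from torch import nn


class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        t = torch.arange(4, dtype=torch.float32)
        self.register_buffer("original_inv_freq", t.clone(), persistent=False)
        self.plain_copy = t.clone()  # old-style attribute


m = WithBuffer().to(torch.float64)
print(m.original_inv_freq.dtype)  # torch.float64: converted with the module
print(m.plain_copy.dtype)         # torch.float32: untouched by .to()
```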
@@ -233,6 +235,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class MoonshineAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 

@@ -264,7 +267,6 @@ class MoonshineAttention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb
 
         # Pad head dimension to the next specified multiple.
         if self.config.pad_head_dim_to_multiple_of is not None:
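What `use_kernelized_func` does internally is not shown in this diff; a plausible reading, consistent with the deleted per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment, is that the decorator attaches the reference function at the class level so an optimized hub kernel can be substituted later. The sketch below is an assumption about that behavior, not the real transformers implementation:

```python
def use_kernelized_func_sketch(reference_fn):
    # Assumed: register the reference implementation on the class; a kernel
    # registry could later replace it with a fused version.
    def wrap(cls):
        cls.rotary_fn = staticmethod(reference_fn)
        return cls

    return wrap


def apply_rotary_pos_emb_ref(q, k, cos, sin):
    return q * cos + q * sin, k * cos + k * sin  # placeholder, not real RoPE math


@use_kernelized_func_sketch(apply_rotary_pos_emb_ref)
class Attention:
    pass


assert Attention().rotary_fn(1.0, 1.0, 1.0, 0.0) == (1.0, 1.0)
```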
transformers/models/moshi/modeling_moshi.py

@@ -34,6 +34,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast,
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+from ...utils.generic import maybe_autocast
 from ..auto.modeling_auto import AutoModel
 from .configuration_moshi import MoshiConfig, MoshiDepthConfig
 

@@ -288,7 +289,7 @@ class MoshiRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(

@@ -327,7 +328,7 @@ class MoshiRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -608,8 +609,8 @@ class MoshiFlashAttention2(MoshiAttention):
                 else torch.get_autocast_gpu_dtype()
             )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
 
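The hunk rewrites flash attention's dtype fallback chain: prefer the autocast dtype, then the config-declared dtype when a quantization config is present (quantized weights may not be in a usable float dtype), else the projection weight's dtype. A sketch of that chain; the helper is illustrative, the attribute names follow the diff:

```python
import torch


def pick_target_dtype(config, q_proj: torch.nn.Linear) -> torch.dtype:
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    elif hasattr(config, "quantization_config"):
        # Quantized weights may be packed integers; the config keeps the
        # intended compute dtype.
        return config.dtype
    return q_proj.weight.dtype
```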
@@ -868,6 +869,8 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
         self.gradient_checkpointing = False
         self.config = config
 
+        self.post_init()
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

@@ -882,6 +885,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
         position_ids: Optional[torch.LongTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         """
         Args:

@@ -957,7 +961,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
             )
             use_cache = False
 
-        if use_cache and past_key_values is None
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
 
         past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()

@@ -1228,6 +1232,7 @@ class MoshiModel(MoshiPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

@@ -2175,6 +2180,7 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
         user_delay_pattern_mask=None,
         moshi_delay_pattern_mask=None,
         kwargs_depth_decoder=None,
+        is_first_iteration=False,
         blank_user_audio_codes: Optional[torch.FloatTensor] = None,
         **kwargs,
     ):

@@ -2186,49 +2192,21 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
         # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
         # (we can't check exception 3 while compiling)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if model_inputs["inputs_embeds"] is not None:
-                batch_size, sequence_length, _ = inputs_embeds.shape
-                device = inputs_embeds.device
-            else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
-
-            attention_mask = self.decoder.model._prepare_4d_causal_attention_mask_with_cache_position(
-                attention_mask,
-                sequence_length=sequence_length,
-                target_length=past_key_values.get_max_cache_shape(),
-                dtype=self.decoder.lm_head.weight.dtype,
-                device=device,
-                cache_position=cache_position,
-                batch_size=batch_size,
-                config=self.config,
-                past_key_values=past_key_values,
-            )
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": use_cache,
-                "attention_mask": attention_mask,
-                "cache_position": cache_position,
-            }
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            logits_to_keep=logits_to_keep,
+            user_delay_pattern_mask=user_delay_pattern_mask,
+            moshi_delay_pattern_mask=moshi_delay_pattern_mask,
+            kwargs_depth_decoder=kwargs_depth_decoder,
+            is_first_iteration=is_first_iteration,
+            blank_user_audio_codes=blank_user_audio_codes,
+            **kwargs,
         )
 
         # 2. Now that everything is prepared, generate audio_codes using the depth decoder
@@ -2267,11 +2245,6 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
         model_inputs["input_ids"] = None
         model_inputs["inputs_embeds"] = inputs_embeds
 
-        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
-        for key, value in kwargs.items():
-            if key not in model_inputs:
-                model_inputs[key] = value
-
         return model_inputs
 
     def _update_model_kwargs_for_generation(
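The rewrite replaces Moshi's hand-rolled mask and position bookkeeping with a single call to the base `prepare_inputs_for_generation`, passing the model-specific extras through as keyword arguments; the manual kwargs-forwarding loop then becomes redundant because the base method already folds unrecognized kwargs into the returned dict. A minimal sketch of the delegation shape (class names hypothetical, only the override structure is taken from the diff):

```python
class BaseForGeneration:
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Stand-in for GenerationMixin: assembles the shared inputs and
        # forwards any extra kwargs into the returned dict.
        return {"input_ids": input_ids, **kwargs}


class SpeechModel(BaseForGeneration):
    def prepare_inputs_for_generation(self, input_ids, *, delay_mask=None, **kwargs):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids, delay_mask=delay_mask, **kwargs
        )
        # Model-specific post-processing happens after the shared bookkeeping.
        if delay_mask is not None:
            model_inputs["input_ids"] = None
        return model_inputs
```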
transformers/models/mpnet/modeling_mpnet.py

@@ -52,6 +52,8 @@ class MPNetPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, MPNetLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, MPNetEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
 class MPNetEmbeddings(nn.Module):

@@ -488,6 +490,7 @@ class MPNetForMaskedLM(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -577,6 +580,7 @@ class MPNetForSequenceClassification(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@@ -656,6 +660,7 @@ class MPNetForMultipleChoice(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):

@@ -748,6 +753,7 @@ class MPNetForTokenClassification(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -831,6 +837,7 @@ class MPNetForQuestionAnswering(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
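The added `MPNetEmbeddings` branch re-derives the deterministic position-id buffer at weight-init time instead of leaving it to whatever the (possibly meta-device) allocation contained. A sketch of exactly what that `init.copy_` computes, with an illustrative length:

```python
import torch

position_ids = torch.zeros(1, 6, dtype=torch.long)  # stand-in for the buffer
with torch.no_grad():
    # Same expression as the diff: [[0, 1, ..., L-1]] copied in place.
    position_ids.copy_(torch.arange(position_ids.shape[-1]).expand((1, -1)))
print(position_ids)  # tensor([[0, 1, 2, 3, 4, 5]])
```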