transformers 5.0.0rc3-py3-none-any.whl → 5.1.0-py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry and is provided for informational purposes only.
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
|
@@ -21,18 +21,16 @@
|
|
|
21
21
|
|
|
22
22
|
import math
|
|
23
23
|
from collections.abc import Callable
|
|
24
|
-
from contextlib import nullcontext
|
|
25
24
|
from typing import Optional
|
|
26
25
|
|
|
27
26
|
import torch
|
|
28
|
-
import torch.nn.functional as F
|
|
29
27
|
from torch import nn
|
|
30
28
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
31
29
|
|
|
32
30
|
from ... import initialization as init
|
|
33
31
|
from ...activations import ACT2FN
|
|
34
|
-
from ...integrations import use_kernel_func_from_hub
|
|
35
|
-
from ...
|
|
32
|
+
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
|
|
33
|
+
from ...masking_utils import create_bidirectional_mask, create_bidirectional_sliding_window_mask
|
|
36
34
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
37
35
|
from ...modeling_outputs import (
|
|
38
36
|
BaseModelOutput,
|
|
@@ -43,158 +41,13 @@ from ...modeling_outputs import (
|
|
|
43
41
|
TokenClassifierOutput,
|
|
44
42
|
)
|
|
45
43
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
46
|
-
from ...modeling_utils import PreTrainedModel
|
|
47
|
-
from ...
|
|
48
|
-
from ...utils
|
|
49
|
-
from ...utils.
|
|
44
|
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
45
|
+
from ...processing_utils import Unpack
|
|
46
|
+
from ...utils import TransformersKwargs, auto_docstring
|
|
47
|
+
from ...utils.generic import can_return_tuple, check_model_inputs, maybe_autocast
|
|
50
48
|
from .configuration_modernbert import ModernBertConfig
|
|
51
49
|
|
|
52
50
|
|
|
53
|
-
if is_flash_attn_2_available():
|
|
54
|
-
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
|
55
|
-
from flash_attn.layers.rotary import RotaryEmbedding
|
|
56
|
-
from flash_attn.ops.triton.rotary import apply_rotary
|
|
57
|
-
else:
|
|
58
|
-
RotaryEmbedding = object
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
logger = logging.get_logger(__name__)
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
class ApplyRotaryEmbUnpad(torch.autograd.Function):
|
|
65
|
-
@staticmethod
|
|
66
|
-
def forward(
|
|
67
|
-
ctx,
|
|
68
|
-
qkv,
|
|
69
|
-
cos,
|
|
70
|
-
sin,
|
|
71
|
-
cu_seqlens: torch.Tensor | None = None,
|
|
72
|
-
max_seqlen: int | None = None,
|
|
73
|
-
):
|
|
74
|
-
# (total_nnz, 3, nheads, headdim)
|
|
75
|
-
qkv = qkv.contiguous()
|
|
76
|
-
total_nnz, _three, _nheads, headdim = qkv.shape
|
|
77
|
-
# We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
|
|
78
|
-
# we get the same tensor
|
|
79
|
-
# qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
|
|
80
|
-
qk = qkv[:, :2].view(total_nnz, -1, headdim)
|
|
81
|
-
apply_rotary(
|
|
82
|
-
qk,
|
|
83
|
-
cos,
|
|
84
|
-
sin,
|
|
85
|
-
seqlen_offsets=0,
|
|
86
|
-
cu_seqlens=cu_seqlens,
|
|
87
|
-
max_seqlen=max_seqlen,
|
|
88
|
-
interleaved=False,
|
|
89
|
-
inplace=True,
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
ctx.save_for_backward(cos, sin, cu_seqlens)
|
|
93
|
-
ctx.max_seqlen = max_seqlen
|
|
94
|
-
return qkv
|
|
95
|
-
|
|
96
|
-
@staticmethod
|
|
97
|
-
def backward(ctx, do):
|
|
98
|
-
cos, sin, cu_seqlens = ctx.saved_tensors
|
|
99
|
-
do = do.contiguous()
|
|
100
|
-
total_nnz, _three, _nheads, headdim = do.shape
|
|
101
|
-
# We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
|
|
102
|
-
# we get the same tensor
|
|
103
|
-
dqk = do[:, :2].view(total_nnz, -1, headdim)
|
|
104
|
-
apply_rotary(
|
|
105
|
-
dqk,
|
|
106
|
-
cos,
|
|
107
|
-
sin,
|
|
108
|
-
seqlen_offsets=0,
|
|
109
|
-
cu_seqlens=cu_seqlens,
|
|
110
|
-
max_seqlen=ctx.max_seqlen,
|
|
111
|
-
interleaved=False,
|
|
112
|
-
inplace=True,
|
|
113
|
-
conjugate=True,
|
|
114
|
-
)
|
|
115
|
-
|
|
116
|
-
return do, None, None, None, None, None, None
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
def apply_rotary_unpadded(
|
|
120
|
-
qkv,
|
|
121
|
-
cos,
|
|
122
|
-
sin,
|
|
123
|
-
cu_seqlens: torch.Tensor | None = None,
|
|
124
|
-
max_seqlen: int | None = None,
|
|
125
|
-
):
|
|
126
|
-
"""
|
|
127
|
-
Arguments:
|
|
128
|
-
qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
|
|
129
|
-
cos, sin: (seqlen_rotary, rotary_dim / 2)
|
|
130
|
-
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
|
131
|
-
of 1st half and 2nd half (GPT-NeoX style).
|
|
132
|
-
inplace: if True, apply rotary embedding in-place.
|
|
133
|
-
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
|
|
134
|
-
Most commonly used in inference when we have KV cache.
|
|
135
|
-
cu_seqlens: (batch + 1,) or None
|
|
136
|
-
max_seqlen: int
|
|
137
|
-
Return:
|
|
138
|
-
out: (total_nnz, dim)
|
|
139
|
-
rotary_dim must be <= headdim
|
|
140
|
-
Apply rotary embedding to the first rotary_dim of x.
|
|
141
|
-
"""
|
|
142
|
-
return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
|
|
146
|
-
"""
|
|
147
|
-
The rotary position embeddings applied directly to unpadded sequences.
|
|
148
|
-
"""
|
|
149
|
-
|
|
150
|
-
def __init__(
|
|
151
|
-
self,
|
|
152
|
-
dim: int,
|
|
153
|
-
base: float = 10000.0,
|
|
154
|
-
max_seqlen: int | None = None,
|
|
155
|
-
device: torch.device | None = None,
|
|
156
|
-
dtype: torch.dtype | None = None,
|
|
157
|
-
):
|
|
158
|
-
"""
|
|
159
|
-
max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
|
|
160
|
-
up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
|
|
161
|
-
the cos_sin_cache will be recomputed during the forward pass.
|
|
162
|
-
"""
|
|
163
|
-
super().__init__(dim=dim, base=base, device=device, interleaved=False)
|
|
164
|
-
self.max_seqlen = max_seqlen
|
|
165
|
-
|
|
166
|
-
if max_seqlen is not None and device is not None and dtype is not None:
|
|
167
|
-
self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
|
|
168
|
-
|
|
169
|
-
def forward(
|
|
170
|
-
self,
|
|
171
|
-
qkv: torch.Tensor,
|
|
172
|
-
cu_seqlens: torch.Tensor,
|
|
173
|
-
max_seqlen: int | None = None,
|
|
174
|
-
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
|
175
|
-
"""
|
|
176
|
-
Apply rotary embedding *inplace* to qkv.
|
|
177
|
-
qkv: (total_nnz, 3, nheads, headdim)
|
|
178
|
-
cu_seqlens: (batch + 1,) cumulative sequence lengths
|
|
179
|
-
max_seqlen: int max seq length in the batch
|
|
180
|
-
"""
|
|
181
|
-
if max_seqlen is not None:
|
|
182
|
-
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
|
183
|
-
|
|
184
|
-
qkv = apply_rotary_unpadded(
|
|
185
|
-
qkv,
|
|
186
|
-
self._cos_cached,
|
|
187
|
-
self._sin_cached,
|
|
188
|
-
cu_seqlens=cu_seqlens,
|
|
189
|
-
max_seqlen=max_seqlen,
|
|
190
|
-
)
|
|
191
|
-
|
|
192
|
-
return qkv
|
|
193
|
-
|
|
194
|
-
def extra_repr(self) -> str:
|
|
195
|
-
return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
|
|
196
|
-
|
|
197
|
-
|
|
198
51
|
class ModernBertEmbeddings(nn.Module):
|
|
199
52
|
"""
|
|
200
53
|
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
|
@@ -207,21 +60,13 @@ class ModernBertEmbeddings(nn.Module):
|
|
|
207
60
|
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
|
|
208
61
|
self.drop = nn.Dropout(config.embedding_dropout)
|
|
209
62
|
|
|
210
|
-
@torch.compile(dynamic=True)
|
|
211
|
-
def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
|
|
212
|
-
return self.drop(self.norm(self.tok_embeddings(input_ids)))
|
|
213
|
-
|
|
214
63
|
def forward(
|
|
215
64
|
self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.Tensor | None = None
|
|
216
65
|
) -> torch.Tensor:
|
|
217
66
|
if inputs_embeds is not None:
|
|
218
67
|
hidden_states = self.drop(self.norm(inputs_embeds))
|
|
219
68
|
else:
|
|
220
|
-
hidden_states = (
|
|
221
|
-
self.compiled_embeddings(input_ids)
|
|
222
|
-
if self.config.reference_compile
|
|
223
|
-
else self.drop(self.norm(self.tok_embeddings(input_ids)))
|
|
224
|
-
)
|
|
69
|
+
hidden_states = self.drop(self.norm(self.tok_embeddings(input_ids)))
|
|
225
70
|
return hidden_states
|
|
226
71
|
|
|
227
72
|
|
|
@@ -326,6 +171,29 @@ class ModernBertRotaryEmbedding(nn.Module):
|
|
|
326
171
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
327
172
|
|
|
328
173
|
|
|
174
|
+
def eager_attention_forward(
|
|
175
|
+
module: nn.Module,
|
|
176
|
+
query: torch.Tensor,
|
|
177
|
+
key: torch.Tensor,
|
|
178
|
+
value: torch.Tensor,
|
|
179
|
+
attention_mask: torch.Tensor | None,
|
|
180
|
+
scaling: float,
|
|
181
|
+
dropout: float = 0.0,
|
|
182
|
+
**kwargs,
|
|
183
|
+
):
|
|
184
|
+
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
|
185
|
+
if attention_mask is not None:
|
|
186
|
+
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
|
187
|
+
attn_weights = attn_weights + causal_mask
|
|
188
|
+
|
|
189
|
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
|
190
|
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
|
191
|
+
|
|
192
|
+
attn_output = torch.matmul(attn_weights, value)
|
|
193
|
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
194
|
+
return attn_output, attn_weights
|
|
195
|
+
|
|
196
|
+
|
|
329
197
|
def rotate_half(x):
|
|
330
198
|
"""Rotates half the hidden dims of the input."""
|
|
331
199
|
x1 = x[..., : x.shape[-1] // 2]
|
|
@@ -352,137 +220,15 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
|
|
352
220
|
Returns:
|
|
353
221
|
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
|
354
222
|
"""
|
|
223
|
+
original_dtype = q.dtype
|
|
355
224
|
cos = cos.unsqueeze(unsqueeze_dim)
|
|
356
225
|
sin = sin.unsqueeze(unsqueeze_dim)
|
|
357
|
-
q_embed = (q * cos) + (rotate_half(q) * sin)
|
|
358
|
-
k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
359
|
-
return q_embed, k_embed
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
def eager_attention_forward(
|
|
363
|
-
module: "ModernBertAttention",
|
|
364
|
-
qkv: torch.Tensor,
|
|
365
|
-
attention_mask: torch.Tensor,
|
|
366
|
-
sliding_window_mask: torch.Tensor,
|
|
367
|
-
position_ids: torch.LongTensor | None,
|
|
368
|
-
local_attention: tuple[int, int],
|
|
369
|
-
bs: int,
|
|
370
|
-
dim: int,
|
|
371
|
-
position_embeddings: torch.Tensor,
|
|
372
|
-
output_attentions: bool | None = False,
|
|
373
|
-
**_kwargs,
|
|
374
|
-
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
|
|
375
|
-
# qkv: [batch_size, seqlen, 3, nheads, headdim]
|
|
376
|
-
cos, sin = position_embeddings
|
|
377
|
-
query, key, value = qkv.transpose(3, 1).unbind(dim=2)
|
|
378
|
-
# query, key, value: [batch_size, heads, seq_len, head_dim]
|
|
379
|
-
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
|
380
|
-
|
|
381
|
-
scale = module.head_dim**-0.5
|
|
382
|
-
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
|
|
383
|
-
|
|
384
|
-
if local_attention != (-1, -1):
|
|
385
|
-
attention_mask = sliding_window_mask
|
|
386
|
-
|
|
387
|
-
attn_weights = attn_weights + attention_mask
|
|
388
|
-
|
|
389
|
-
# upcast attention to fp32
|
|
390
|
-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
|
391
|
-
attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
|
|
392
|
-
attn_output = torch.matmul(attn_weights, value)
|
|
393
|
-
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
394
|
-
attn_output = attn_output.view(bs, -1, dim)
|
|
395
|
-
if output_attentions:
|
|
396
|
-
return (attn_output, attn_weights)
|
|
397
|
-
return (attn_output,)
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
def flash_attention_forward(
|
|
401
|
-
module: "ModernBertAttention",
|
|
402
|
-
qkv: torch.Tensor,
|
|
403
|
-
rotary_emb: ModernBertUnpaddedRotaryEmbedding,
|
|
404
|
-
cu_seqlens: torch.Tensor,
|
|
405
|
-
max_seqlen: int,
|
|
406
|
-
local_attention: tuple[int, int],
|
|
407
|
-
bs: int,
|
|
408
|
-
dim: int,
|
|
409
|
-
target_dtype: torch.dtype = torch.bfloat16,
|
|
410
|
-
**_kwargs,
|
|
411
|
-
) -> tuple[torch.Tensor]:
|
|
412
|
-
# (total_seqlen, 3, nheads, headdim)
|
|
413
|
-
qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
|
|
414
|
-
|
|
415
|
-
convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
|
|
416
|
-
if convert_dtype:
|
|
417
|
-
# FA2 implementation only supports fp16 and bf16. If FA2 is supported,
|
|
418
|
-
# bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
|
|
419
|
-
orig_dtype = qkv.dtype
|
|
420
|
-
qkv = qkv.to(target_dtype)
|
|
421
|
-
|
|
422
|
-
attn = flash_attn_varlen_qkvpacked_func(
|
|
423
|
-
qkv,
|
|
424
|
-
cu_seqlens=cu_seqlens,
|
|
425
|
-
max_seqlen=max_seqlen,
|
|
426
|
-
dropout_p=module.attention_dropout if module.training else 0.0,
|
|
427
|
-
deterministic=module.deterministic_flash_attn,
|
|
428
|
-
window_size=local_attention,
|
|
429
|
-
)
|
|
430
|
-
attn = attn.to(orig_dtype) # type: ignore
|
|
431
|
-
else:
|
|
432
|
-
attn = flash_attn_varlen_qkvpacked_func(
|
|
433
|
-
qkv,
|
|
434
|
-
cu_seqlens=cu_seqlens,
|
|
435
|
-
max_seqlen=max_seqlen,
|
|
436
|
-
dropout_p=module.attention_dropout if module.training else 0.0,
|
|
437
|
-
deterministic=module.deterministic_flash_attn,
|
|
438
|
-
window_size=local_attention,
|
|
439
|
-
)
|
|
440
|
-
return (attn.view(bs, dim),)
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
def sdpa_attention_forward(
|
|
444
|
-
module: "ModernBertAttention",
|
|
445
|
-
qkv: torch.Tensor,
|
|
446
|
-
attention_mask: torch.Tensor,
|
|
447
|
-
sliding_window_mask: torch.Tensor,
|
|
448
|
-
position_ids: torch.LongTensor | None,
|
|
449
|
-
local_attention: tuple[int, int],
|
|
450
|
-
bs: int,
|
|
451
|
-
dim: int,
|
|
452
|
-
position_embeddings: torch.Tensor,
|
|
453
|
-
**_kwargs,
|
|
454
|
-
) -> tuple[torch.Tensor]:
|
|
455
|
-
# qkv: [batch_size, seqlen, 3, nheads, headdim]
|
|
456
|
-
cos, sin = position_embeddings
|
|
457
|
-
query, key, value = qkv.transpose(3, 1).unbind(dim=2)
|
|
458
|
-
# query, key, value: [batch_size, heads, seq_len, head_dim]
|
|
459
|
-
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
|
460
|
-
|
|
461
|
-
if local_attention != (-1, -1):
|
|
462
|
-
attention_mask = sliding_window_mask
|
|
463
|
-
|
|
464
|
-
attn_output = (
|
|
465
|
-
F.scaled_dot_product_attention(
|
|
466
|
-
query,
|
|
467
|
-
key,
|
|
468
|
-
value,
|
|
469
|
-
dropout_p=module.attention_dropout if module.training else 0.0,
|
|
470
|
-
attn_mask=attention_mask,
|
|
471
|
-
)
|
|
472
|
-
.transpose(1, 2)
|
|
473
|
-
.contiguous()
|
|
474
|
-
)
|
|
475
|
-
attn_output = attn_output.view(bs, -1, dim)
|
|
476
|
-
return (attn_output,)
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
MODERNBERT_ATTENTION_FUNCTION = {
|
|
480
|
-
"flash_attention_2": flash_attention_forward,
|
|
481
|
-
"eager": eager_attention_forward,
|
|
482
|
-
"sdpa": sdpa_attention_forward,
|
|
483
|
-
}
|
|
226
|
+
q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
|
|
227
|
+
k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
|
|
228
|
+
return q_embed.to(original_dtype), k_embed.to(original_dtype)
|
|
484
229
|
|
|
485
230
|
|
|
231
|
+
@use_kernelized_func(apply_rotary_pos_emb)
|
|
486
232
|
class ModernBertAttention(nn.Module):
|
|
487
233
|
"""Performs multi-headed self attention on a batch of unpadded sequences.
|
|
488
234
|
|
|
@@ -493,10 +239,10 @@ class ModernBertAttention(nn.Module):
|
|
|
493
239
|
See `forward` method for additional details.
|
|
494
240
|
"""
|
|
495
241
|
|
|
496
|
-
def __init__(self, config: ModernBertConfig,
|
|
242
|
+
def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
|
|
497
243
|
super().__init__()
|
|
498
244
|
self.config = config
|
|
499
|
-
self.
|
|
245
|
+
self.layer_idx = layer_idx
|
|
500
246
|
|
|
501
247
|
if config.hidden_size % config.num_attention_heads != 0:
|
|
502
248
|
raise ValueError(
|
|
@@ -505,29 +251,19 @@ class ModernBertAttention(nn.Module):
|
|
|
505
251
|
|
|
506
252
|
self.attention_dropout = config.attention_dropout
|
|
507
253
|
self.deterministic_flash_attn = config.deterministic_flash_attn
|
|
508
|
-
self.num_heads = config.num_attention_heads
|
|
509
254
|
self.head_dim = config.hidden_size // config.num_attention_heads
|
|
510
|
-
self.
|
|
511
|
-
|
|
512
|
-
|
|
255
|
+
self.Wqkv = nn.Linear(
|
|
256
|
+
config.hidden_size, 3 * self.head_dim * config.num_attention_heads, bias=config.attention_bias
|
|
257
|
+
)
|
|
513
258
|
|
|
514
|
-
if
|
|
515
|
-
|
|
516
|
-
|
|
259
|
+
if config.layer_types[layer_idx] == "sliding_attention":
|
|
260
|
+
# config.sliding_window = local_attention // 2 (half-window size, e.g. 64 for local_attention=128)
|
|
261
|
+
# +1 is needed because flash attention sets inclusive boundaries (see modeling_flash_attention_utils.py)
|
|
262
|
+
self.sliding_window = config.sliding_window + 1
|
|
517
263
|
else:
|
|
518
|
-
self.
|
|
519
|
-
max_position_embeddings = config.max_position_embeddings
|
|
264
|
+
self.sliding_window = None
|
|
520
265
|
|
|
521
|
-
|
|
522
|
-
rope_parameters_dict = (
|
|
523
|
-
self.config.rope_parameters[layer_type] if layer_type is not None else self.config.rope_parameters
|
|
524
|
-
)
|
|
525
|
-
rope_theta = rope_parameters_dict["rope_theta"]
|
|
526
|
-
self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
|
|
527
|
-
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
|
|
528
|
-
)
|
|
529
|
-
else:
|
|
530
|
-
self.rotary_emb = None
|
|
266
|
+
self.is_causal = False
|
|
531
267
|
|
|
532
268
|
self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
|
|
533
269
|
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
|
|
@@ -535,82 +271,75 @@ class ModernBertAttention(nn.Module):
|
|
|
535
271
|
def forward(
|
|
536
272
|
self,
|
|
537
273
|
hidden_states: torch.Tensor,
|
|
538
|
-
position_embeddings: torch.Tensor | None = None,
|
|
539
|
-
|
|
540
|
-
**kwargs,
|
|
541
|
-
) -> torch.Tensor:
|
|
274
|
+
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
|
275
|
+
attention_mask: torch.Tensor | None = None,
|
|
276
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
277
|
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
278
|
+
input_shape = hidden_states.shape[:-1]
|
|
279
|
+
|
|
542
280
|
qkv = self.Wqkv(hidden_states)
|
|
281
|
+
qkv = qkv.view(*input_shape, 3, -1, self.head_dim)
|
|
282
|
+
query_states, key_states, value_states = qkv.unbind(dim=-3)
|
|
543
283
|
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
284
|
+
query_states = query_states.transpose(1, 2)
|
|
285
|
+
key_states = key_states.transpose(1, 2)
|
|
286
|
+
value_states = value_states.transpose(1, 2)
|
|
287
|
+
|
|
288
|
+
cos, sin = position_embeddings
|
|
289
|
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)
|
|
549
290
|
|
|
550
|
-
|
|
291
|
+
attention_interface = eager_attention_forward
|
|
292
|
+
if self.config._attn_implementation != "eager":
|
|
293
|
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
|
294
|
+
|
|
295
|
+
attn_output, attn_weights = attention_interface(
|
|
551
296
|
self,
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
297
|
+
query_states,
|
|
298
|
+
key_states,
|
|
299
|
+
value_states,
|
|
300
|
+
attention_mask,
|
|
301
|
+
dropout=self.attention_dropout if self.training else 0.0,
|
|
302
|
+
scaling=self.head_dim**-0.5,
|
|
303
|
+
sliding_window=self.sliding_window,
|
|
304
|
+
deterministic=self.deterministic_flash_attn,
|
|
559
305
|
**kwargs,
|
|
560
306
|
)
|
|
561
|
-
hidden_states = attn_outputs[0]
|
|
562
|
-
hidden_states = self.out_drop(self.Wo(hidden_states))
|
|
563
307
|
|
|
564
|
-
|
|
308
|
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
|
309
|
+
attn_output = self.out_drop(self.Wo(attn_output))
|
|
310
|
+
return attn_output, attn_weights
|
|
565
311
|
|
|
566
312
|
|
|
567
313
|
class ModernBertEncoderLayer(GradientCheckpointingLayer):
|
|
568
|
-
def __init__(self, config: ModernBertConfig,
|
|
314
|
+
def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
|
|
569
315
|
super().__init__()
|
|
570
316
|
self.config = config
|
|
571
|
-
|
|
317
|
+
self.layer_idx = layer_idx
|
|
318
|
+
if layer_idx == 0:
|
|
572
319
|
self.attn_norm = nn.Identity()
|
|
573
320
|
else:
|
|
574
321
|
self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
|
|
575
|
-
self.attn = ModernBertAttention(config=config,
|
|
322
|
+
self.attn = ModernBertAttention(config=config, layer_idx=layer_idx)
|
|
576
323
|
self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
|
|
577
324
|
self.mlp = ModernBertMLP(config)
|
|
578
|
-
self.attention_type = config.layer_types[
|
|
579
|
-
|
|
580
|
-
@torch.compile(dynamic=True)
|
|
581
|
-
def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
582
|
-
return self.mlp(self.mlp_norm(hidden_states))
|
|
325
|
+
self.attention_type = config.layer_types[layer_idx]
|
|
583
326
|
|
|
584
327
|
def forward(
|
|
585
328
|
self,
|
|
586
329
|
hidden_states: torch.Tensor,
|
|
587
330
|
attention_mask: torch.Tensor | None = None,
|
|
588
|
-
sliding_window_mask: torch.Tensor | None = None,
|
|
589
|
-
position_ids: torch.LongTensor | None = None,
|
|
590
|
-
cu_seqlens: torch.Tensor | None = None,
|
|
591
|
-
max_seqlen: int | None = None,
|
|
592
331
|
position_embeddings: torch.Tensor | None = None,
|
|
593
|
-
|
|
332
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
594
333
|
) -> torch.Tensor:
|
|
595
|
-
|
|
334
|
+
attn_output, _ = self.attn(
|
|
596
335
|
self.attn_norm(hidden_states),
|
|
597
|
-
attention_mask=attention_mask,
|
|
598
|
-
sliding_window_mask=sliding_window_mask,
|
|
599
|
-
position_ids=position_ids,
|
|
600
|
-
cu_seqlens=cu_seqlens,
|
|
601
|
-
max_seqlen=max_seqlen,
|
|
602
336
|
position_embeddings=position_embeddings,
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
hidden_states = hidden_states + attn_outputs[0]
|
|
606
|
-
mlp_output = (
|
|
607
|
-
self.compiled_mlp(hidden_states)
|
|
608
|
-
if self.config.reference_compile
|
|
609
|
-
else self.mlp(self.mlp_norm(hidden_states))
|
|
337
|
+
attention_mask=attention_mask,
|
|
338
|
+
**kwargs,
|
|
610
339
|
)
|
|
611
|
-
hidden_states = hidden_states +
|
|
612
|
-
|
|
613
|
-
return
|
|
340
|
+
hidden_states = hidden_states + attn_output
|
|
341
|
+
hidden_states = hidden_states + self.mlp(self.mlp_norm(hidden_states))
|
|
342
|
+
return hidden_states
|
|
614
343
|
|
|
615
344
|
|
|
616
345
|
@auto_docstring
|
|
@@ -621,7 +350,13 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|
|
621
350
|
_no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
|
|
622
351
|
_supports_flash_attn = True
|
|
623
352
|
_supports_sdpa = True
|
|
624
|
-
_supports_flex_attn =
|
|
353
|
+
_supports_flex_attn = True
|
|
354
|
+
_supports_attention_backend = True
|
|
355
|
+
|
|
356
|
+
_can_record_outputs = {
|
|
357
|
+
"hidden_states": ModernBertEncoderLayer,
|
|
358
|
+
"attentions": ModernBertAttention,
|
|
359
|
+
}
|
|
625
360
|
|
|
626
361
|
@torch.no_grad()
|
|
627
362
|
def _init_weights(self, module: nn.Module):
|
|
@@ -683,9 +418,6 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|
|
683
418
|
curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
|
|
684
419
|
init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
|
|
685
420
|
init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
|
|
686
|
-
elif isinstance(module, ModernBertUnpaddedRotaryEmbedding):
|
|
687
|
-
inv_freq = module._compute_inv_freq()
|
|
688
|
-
init.copy_(module.inv_freq, inv_freq)
|
|
689
421
|
|
|
690
422
|
def _check_and_adjust_attn_implementation(
|
|
691
423
|
self, attn_implementation: str | None, is_init_check: bool = False
|
|
@@ -693,137 +425,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|
|
693
425
|
"""
|
|
694
426
|
Checks and dispatches to hhe requested attention implementation.
|
|
695
427
|
"""
|
|
696
|
-
# If the user didn't specify anything, try to use flash_attention_2
|
|
428
|
+
# If the user didn't specify anything, try to use flash_attention_2.
|
|
697
429
|
# Otherwise we fall back to the default SDPA -> Eager from the super() method.
|
|
698
|
-
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
|
|
699
|
-
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
|
|
700
|
-
|
|
701
430
|
try:
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
else attn_implementation
|
|
431
|
+
requested_attn_implementation = "flash_attention_2" if attn_implementation is None else attn_implementation
|
|
432
|
+
return super()._check_and_adjust_attn_implementation(
|
|
433
|
+
attn_implementation=requested_attn_implementation, is_init_check=is_init_check
|
|
706
434
|
)
|
|
707
435
|
except (ValueError, ImportError):
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
)
|
|
712
|
-
|
|
713
|
-
def _maybe_set_compile(self):
|
|
714
|
-
if self.config.reference_compile is False:
|
|
715
|
-
return
|
|
716
|
-
|
|
717
|
-
if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
|
|
718
|
-
if self.config.reference_compile:
|
|
719
|
-
logger.warning_once(
|
|
720
|
-
"If `accelerate` split the model across devices, `torch.compile` will not work. "
|
|
721
|
-
"Falling back to non-compiled mode."
|
|
722
|
-
)
|
|
723
|
-
self.config.reference_compile = False
|
|
724
|
-
|
|
725
|
-
if self.device.type == "mps":
|
|
726
|
-
if self.config.reference_compile:
|
|
727
|
-
logger.warning_once(
|
|
728
|
-
"Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
|
|
729
|
-
"Falling back to non-compiled mode."
|
|
730
|
-
)
|
|
731
|
-
self.config.reference_compile = False
|
|
732
|
-
|
|
733
|
-
if self.device.type == "cpu":
|
|
734
|
-
if self.config.reference_compile:
|
|
735
|
-
logger.warning_once(
|
|
736
|
-
"Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
|
|
737
|
-
"Falling back to non-compiled mode."
|
|
738
|
-
)
|
|
739
|
-
self.config.reference_compile = False
|
|
740
|
-
|
|
741
|
-
if self.config.reference_compile is None:
|
|
742
|
-
self.config.reference_compile = is_triton_available()
|
|
743
|
-
|
|
744
|
-
def resize_token_embeddings(self, *args, **kwargs):
|
|
745
|
-
model_embeds = super().resize_token_embeddings(*args, **kwargs)
|
|
746
|
-
|
|
747
|
-
if self.config.reference_compile in {True, None}:
|
|
748
|
-
if self.config.reference_compile:
|
|
749
|
-
logger.warning_once(
|
|
750
|
-
"Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
|
|
751
|
-
)
|
|
752
|
-
self.config.reference_compile = False
|
|
753
|
-
|
|
754
|
-
return model_embeds
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
def _unpad_modernbert_input(
|
|
758
|
-
inputs: torch.Tensor,
|
|
759
|
-
attention_mask: torch.Tensor,
|
|
760
|
-
position_ids: torch.Tensor | None = None,
|
|
761
|
-
labels: torch.Tensor | None = None,
|
|
762
|
-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor | None, torch.Tensor | None]:
|
|
763
|
-
"""
|
|
764
|
-
Remove padding from input sequences.
|
|
765
|
-
|
|
766
|
-
Args:
|
|
767
|
-
inputs: (batch, seqlen, ...) or (batch, seqlen)
|
|
768
|
-
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
|
769
|
-
position_ids: (batch, seqlen), int, position ids
|
|
770
|
-
labels: (batch, seqlen), int, labels
|
|
771
|
-
|
|
772
|
-
Returns:
|
|
773
|
-
unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
|
|
774
|
-
indices: (total_nnz)
|
|
775
|
-
cu_seqlens: (batch + 1), the cumulative sequence lengths
|
|
776
|
-
max_seqlen_in_batch: int
|
|
777
|
-
unpadded_position_ids: (total_nnz) or None
|
|
778
|
-
unpadded_labels: (total_nnz) or None
|
|
779
|
-
"""
|
|
780
|
-
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
|
781
|
-
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
|
782
|
-
max_seqlen_in_batch = int(seqlens_in_batch.max().item())
|
|
783
|
-
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
|
784
|
-
|
|
785
|
-
if inputs.dim() == 2:
|
|
786
|
-
unpadded_inputs = inputs.flatten()[indices]
|
|
787
|
-
else:
|
|
788
|
-
batch, seqlen, *rest = inputs.shape
|
|
789
|
-
shape = batch * seqlen
|
|
790
|
-
unpadded_inputs = inputs.view(shape, *rest)[indices]
|
|
791
|
-
|
|
792
|
-
unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
|
|
793
|
-
unpadded_labels = labels.flatten()[indices] if labels is not None else None
|
|
794
|
-
|
|
795
|
-
return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
def _pad_modernbert_output(
|
|
799
|
-
inputs: torch.Tensor,
|
|
800
|
-
indices: torch.Tensor,
|
|
801
|
-
batch: int,
|
|
802
|
-
seqlen: int,
|
|
803
|
-
) -> torch.Tensor:
|
|
804
|
-
"""
|
|
805
|
-
Add padding to sequences.
|
|
806
|
-
|
|
807
|
-
Args:
|
|
808
|
-
inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
|
|
809
|
-
indices: (total_nnz)
|
|
810
|
-
batch: int, batch size
|
|
811
|
-
seqlen: int, max sequence length
|
|
812
|
-
|
|
813
|
-
Returns:
|
|
814
|
-
padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
|
|
815
|
-
"""
|
|
816
|
-
if inputs.dim() == 1:
|
|
817
|
-
output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
|
|
818
|
-
output[indices] = inputs
|
|
819
|
-
padded_inputs = output.view(batch, seqlen)
|
|
820
|
-
else:
|
|
821
|
-
_, *rest = inputs.shape
|
|
822
|
-
output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
|
|
823
|
-
output[indices] = inputs
|
|
824
|
-
padded_inputs = output.view(batch, seqlen, *rest)
|
|
825
|
-
|
|
826
|
-
return padded_inputs
|
|
436
|
+
return super()._check_and_adjust_attn_implementation(
|
|
437
|
+
attn_implementation=attn_implementation, is_init_check=is_init_check
|
|
438
|
+
)
|
|
827
439
|
|
|
828
440
|
|
|
829
441
|
@auto_docstring
|
|
@@ -833,7 +445,7 @@ class ModernBertModel(ModernBertPreTrainedModel):
|
|
|
833
445
|
self.config = config
|
|
834
446
|
self.embeddings = ModernBertEmbeddings(config)
|
|
835
447
|
self.layers = nn.ModuleList(
|
|
836
|
-
[ModernBertEncoderLayer(config,
|
|
448
|
+
[ModernBertEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
|
837
449
|
)
|
|
838
450
|
self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
|
|
839
451
|
self.rotary_emb = ModernBertRotaryEmbedding(config=config)
|
|
@@ -846,175 +458,53 @@ class ModernBertModel(ModernBertPreTrainedModel):
|
|
|
846
458
|
def set_input_embeddings(self, value):
|
|
847
459
|
self.embeddings.tok_embeddings = value
|
|
848
460
|
|
|
461
|
+
@check_model_inputs
|
|
849
462
|
@auto_docstring
|
|
850
463
|
def forward(
|
|
851
464
|
self,
|
|
852
465
|
input_ids: torch.LongTensor | None = None,
|
|
853
466
|
attention_mask: torch.Tensor | None = None,
|
|
854
|
-
sliding_window_mask: torch.Tensor | None = None,
|
|
855
467
|
position_ids: torch.LongTensor | None = None,
|
|
856
468
|
inputs_embeds: torch.Tensor | None = None,
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
max_seqlen: int | None = None,
|
|
860
|
-
batch_size: int | None = None,
|
|
861
|
-
seq_len: int | None = None,
|
|
862
|
-
output_attentions: bool | None = None,
|
|
863
|
-
output_hidden_states: bool | None = None,
|
|
864
|
-
return_dict: bool | None = None,
|
|
865
|
-
**kwargs,
|
|
866
|
-
) -> tuple[torch.Tensor, ...] | BaseModelOutput:
|
|
867
|
-
r"""
|
|
868
|
-
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
869
|
-
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
|
870
|
-
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
|
871
|
-
far-away tokens in the local attention layers when not using Flash Attention.
|
|
872
|
-
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
|
873
|
-
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
|
874
|
-
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
|
875
|
-
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
|
876
|
-
max_seqlen (`int`, *optional*):
|
|
877
|
-
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
|
|
878
|
-
batch_size (`int`, *optional*):
|
|
879
|
-
Batch size of the input sequences. Used to pad the output tensors.
|
|
880
|
-
seq_len (`int`, *optional*):
|
|
881
|
-
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
|
|
882
|
-
"""
|
|
883
|
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
884
|
-
output_hidden_states = (
|
|
885
|
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
886
|
-
)
|
|
887
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
888
|
-
|
|
469
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
470
|
+
) -> BaseModelOutput:
|
|
889
471
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
890
472
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
891
473
|
|
|
892
|
-
|
|
893
|
-
all_self_attentions = () if output_attentions else None
|
|
894
|
-
|
|
895
|
-
self._maybe_set_compile()
|
|
896
|
-
|
|
897
|
-
if input_ids is not None:
|
|
898
|
-
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
|
899
|
-
|
|
900
|
-
if batch_size is None and seq_len is None:
|
|
901
|
-
if inputs_embeds is not None:
|
|
902
|
-
batch_size, seq_len = inputs_embeds.shape[:2]
|
|
903
|
-
else:
|
|
904
|
-
batch_size, seq_len = input_ids.shape[:2]
|
|
474
|
+
seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
|
|
905
475
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
906
476
|
|
|
907
|
-
if
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
repad = False
|
|
911
|
-
if self.config._attn_implementation == "flash_attention_2":
|
|
912
|
-
if indices is None and cu_seqlens is None and max_seqlen is None:
|
|
913
|
-
repad = True
|
|
914
|
-
if inputs_embeds is None:
|
|
915
|
-
with torch.no_grad():
|
|
916
|
-
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
|
|
917
|
-
inputs=input_ids, attention_mask=attention_mask
|
|
918
|
-
)
|
|
919
|
-
else:
|
|
920
|
-
inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
|
|
921
|
-
inputs=inputs_embeds, attention_mask=attention_mask
|
|
922
|
-
)
|
|
923
|
-
if position_ids is None:
|
|
924
|
-
position_ids = indices.unsqueeze(0)
|
|
925
|
-
else:
|
|
926
|
-
if position_ids is None:
|
|
927
|
-
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
|
|
928
|
-
|
|
929
|
-
attention_mask, sliding_window_mask = self._update_attention_mask(
|
|
930
|
-
attention_mask, output_attentions=output_attentions
|
|
931
|
-
)
|
|
477
|
+
if position_ids is None:
|
|
478
|
+
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
|
|
932
479
|
|
|
933
480
|
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
|
|
481
|
+
|
|
482
|
+
if not isinstance(attention_mask_mapping := attention_mask, dict):
|
|
483
|
+
mask_kwargs = {
|
|
484
|
+
"config": self.config,
|
|
485
|
+
"input_embeds": hidden_states,
|
|
486
|
+
"attention_mask": attention_mask,
|
|
487
|
+
}
|
|
488
|
+
attention_mask_mapping = {
|
|
489
|
+
"full_attention": create_bidirectional_mask(**mask_kwargs),
|
|
490
|
+
"sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
|
|
491
|
+
}
|
|
492
|
+
|
|
934
493
|
position_embeddings = {}
|
|
935
494
|
for layer_type in self.config.layer_types:
|
|
936
495
|
position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
|
|
937
496
|
|
|
938
497
|
for encoder_layer in self.layers:
|
|
939
|
-
|
|
940
|
-
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
941
|
-
|
|
942
|
-
layer_outputs = encoder_layer(
|
|
498
|
+
hidden_states = encoder_layer(
|
|
943
499
|
hidden_states,
|
|
944
|
-
attention_mask=
|
|
945
|
-
sliding_window_mask=sliding_window_mask,
|
|
946
|
-
position_ids=position_ids,
|
|
947
|
-
cu_seqlens=cu_seqlens,
|
|
948
|
-
max_seqlen=max_seqlen,
|
|
500
|
+
attention_mask=attention_mask_mapping[encoder_layer.attention_type],
|
|
949
501
|
position_embeddings=position_embeddings[encoder_layer.attention_type],
|
|
950
|
-
|
|
502
|
+
**kwargs,
|
|
951
503
|
)
|
|
952
|
-
hidden_states = layer_outputs[0]
|
|
953
|
-
if output_attentions and len(layer_outputs) > 1:
|
|
954
|
-
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
|
955
|
-
|
|
956
|
-
if output_hidden_states:
|
|
957
|
-
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
958
504
|
|
|
959
505
|
hidden_states = self.final_norm(hidden_states)
|
|
960
506
|
|
|
961
|
-
|
|
962
|
-
hidden_states = _pad_modernbert_output(
|
|
963
|
-
inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
|
|
964
|
-
)
|
|
965
|
-
if all_hidden_states is not None:
|
|
966
|
-
all_hidden_states = tuple(
|
|
967
|
-
_pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
|
|
968
|
-
for hs in all_hidden_states
|
|
969
|
-
)
|
|
970
|
-
# If the attention implementation is FA2 and there is no need for repadding, there might still be the batch
|
|
971
|
-
# dimension missing
|
|
972
|
-
elif (
|
|
973
|
-
self.config._attn_implementation == "flash_attention_2"
|
|
974
|
-
and all_hidden_states is not None
|
|
975
|
-
and all_hidden_states[-1].dim() == 2
|
|
976
|
-
):
|
|
977
|
-
hidden_states = hidden_states.unsqueeze(0)
|
|
978
|
-
all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states)
|
|
979
|
-
|
|
980
|
-
if not return_dict:
|
|
981
|
-
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
|
982
|
-
return BaseModelOutput(
|
|
983
|
-
last_hidden_state=hidden_states,
|
|
984
|
-
hidden_states=all_hidden_states,
|
|
985
|
-
attentions=all_self_attentions,
|
|
986
|
-
)
|
|
987
|
-
|
|
988
|
-
def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
|
|
989
|
-
if output_attentions:
|
|
990
|
-
if self.config._attn_implementation == "sdpa":
|
|
991
|
-
logger.warning_once(
|
|
992
|
-
"Outputting attentions is only supported with the 'eager' attention implementation, "
|
|
993
|
-
'not with "sdpa". Falling back to `attn_implementation="eager"`.'
|
|
994
|
-
)
|
|
995
|
-
self.config._attn_implementation = "eager"
|
|
996
|
-
elif self.config._attn_implementation != "eager":
|
|
997
|
-
logger.warning_once(
|
|
998
|
-
"Outputting attentions is only supported with the eager attention implementation, "
|
|
999
|
-
f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
|
|
1000
|
-
" Setting `output_attentions=False`."
|
|
1001
|
-
)
|
|
1002
|
-
|
|
1003
|
-
global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
|
|
1004
|
-
|
|
1005
|
-
# Create position indices
|
|
1006
|
-
rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
|
|
1007
|
-
# Calculate distance between positions
|
|
1008
|
-
distance = torch.abs(rows - rows.T)
|
|
1009
|
-
|
|
1010
|
-
# Create sliding window mask (1 for positions within window, 0 outside)
|
|
1011
|
-
window_mask = (
|
|
1012
|
-
(distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
|
|
1013
|
-
)
|
|
1014
|
-
# Combine with existing mask
|
|
1015
|
-
sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
|
|
1016
|
-
|
|
1017
|
-
return global_attention_mask, sliding_window_mask
|
|
507
|
+
return BaseModelOutput(last_hidden_state=hidden_states)
|
|
1018
508
|
|
|
1019
509
|
|
|
1020
510
|
class ModernBertPredictionHead(nn.Module):
|
|
@@ -1056,84 +546,23 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
|
|
1056
546
|
def set_output_embeddings(self, new_embeddings: nn.Linear):
|
|
1057
547
|
self.decoder = new_embeddings
|
|
1058
548
|
|
|
1059
|
-
@
|
|
1060
|
-
def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
|
|
1061
|
-
return self.decoder(self.head(output))
|
|
1062
|
-
|
|
549
|
+
@can_return_tuple
|
|
1063
550
|
@auto_docstring
|
|
1064
551
|
def forward(
|
|
1065
552
|
self,
|
|
1066
553
|
input_ids: torch.LongTensor | None = None,
|
|
1067
554
|
attention_mask: torch.Tensor | None = None,
|
|
1068
|
-
sliding_window_mask: torch.Tensor | None = None,
|
|
1069
555
|
position_ids: torch.Tensor | None = None,
|
|
1070
556
|
inputs_embeds: torch.Tensor | None = None,
|
|
1071
557
|
labels: torch.Tensor | None = None,
|
|
1072
|
-
|
|
1073
|
-
cu_seqlens: torch.Tensor | None = None,
|
|
1074
|
-
max_seqlen: int | None = None,
|
|
1075
|
-
batch_size: int | None = None,
|
|
1076
|
-
seq_len: int | None = None,
|
|
1077
|
-
output_attentions: bool | None = None,
|
|
1078
|
-
output_hidden_states: bool | None = None,
|
|
1079
|
-
return_dict: bool | None = None,
|
|
1080
|
-
**kwargs,
|
|
558
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
1081
559
|
) -> tuple[torch.Tensor] | MaskedLMOutput:
|
|
1082
|
-
r"""
|
|
1083
|
-
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
1084
|
-
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
|
1085
|
-
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
|
1086
|
-
far-away tokens in the local attention layers when not using Flash Attention.
|
|
1087
|
-
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
|
1088
|
-
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
|
1089
|
-
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
|
1090
|
-
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
|
1091
|
-
max_seqlen (`int`, *optional*):
|
|
1092
|
-
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
|
|
1093
|
-
batch_size (`int`, *optional*):
|
|
1094
|
-
Batch size of the input sequences. Used to pad the output tensors.
|
|
1095
|
-
seq_len (`int`, *optional*):
|
|
1096
|
-
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
|
|
1097
|
-
"""
|
|
1098
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1099
|
-
self._maybe_set_compile()
|
|
1100
|
-
|
|
1101
|
-
if self.config._attn_implementation == "flash_attention_2":
|
|
1102
|
-
if indices is None and cu_seqlens is None and max_seqlen is None:
|
|
1103
|
-
if batch_size is None and seq_len is None:
|
|
1104
|
-
if inputs_embeds is not None:
|
|
1105
|
-
batch_size, seq_len = inputs_embeds.shape[:2]
|
|
1106
|
-
else:
|
|
1107
|
-
batch_size, seq_len = input_ids.shape[:2]
|
|
1108
|
-
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
1109
|
-
|
|
1110
|
-
if attention_mask is None:
|
|
1111
|
-
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
|
|
1112
|
-
|
|
1113
|
-
if inputs_embeds is None:
|
|
1114
|
-
with torch.no_grad():
|
|
1115
|
-
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
|
|
1116
|
-
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
|
|
1117
|
-
)
|
|
1118
|
-
else:
|
|
1119
|
-
inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
|
|
1120
|
-
inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
|
|
1121
|
-
)
|
|
1122
|
-
|
|
1123
560
|
outputs = self.model(
|
|
1124
561
|
input_ids=input_ids,
|
|
1125
562
|
attention_mask=attention_mask,
|
|
1126
|
-
sliding_window_mask=sliding_window_mask,
|
|
1127
563
|
position_ids=position_ids,
|
|
1128
564
|
inputs_embeds=inputs_embeds,
|
|
1129
|
-
|
|
1130
|
-
cu_seqlens=cu_seqlens,
|
|
1131
|
-
max_seqlen=max_seqlen,
|
|
1132
|
-
batch_size=batch_size,
|
|
1133
|
-
seq_len=seq_len,
|
|
1134
|
-
output_attentions=output_attentions,
|
|
1135
|
-
output_hidden_states=output_hidden_states,
|
|
1136
|
-
return_dict=return_dict,
|
|
565
|
+
**kwargs,
|
|
1137
566
|
)
|
|
1138
567
|
last_hidden_state = outputs[0]
|
|
1139
568
|
|
|
@@ -1147,35 +576,12 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
|
|
1147
576
|
last_hidden_state = last_hidden_state[mask_tokens]
|
|
1148
577
|
labels = labels[mask_tokens]
|
|
1149
578
|
|
|
1150
|
-
logits = (
|
|
1151
|
-
self.compiled_head(last_hidden_state)
|
|
1152
|
-
if self.config.reference_compile
|
|
1153
|
-
else self.decoder(self.head(last_hidden_state))
|
|
1154
|
-
)
|
|
579
|
+
logits = self.decoder(self.head(last_hidden_state))
|
|
1155
580
|
|
|
1156
581
|
loss = None
|
|
1157
582
|
if labels is not None:
|
|
1158
583
|
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
|
|
1159
584
|
|
|
1160
|
-
if self.config._attn_implementation == "flash_attention_2":
|
|
1161
|
-
# Logits padding
|
|
1162
|
-
with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
|
|
1163
|
-
logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
|
|
1164
|
-
# Hidden states padding
|
|
1165
|
-
if getattr(outputs, "hidden_states", None) is not None:
|
|
1166
|
-
padded_hidden_states = []
|
|
1167
|
-
for hs in outputs.hidden_states:
|
|
1168
|
-
if hs.dim() == 3 and hs.shape[0] == 1:
|
|
1169
|
-
hs = hs.squeeze(0)
|
|
1170
|
-
padded_hidden_states.append(
|
|
1171
|
-
_pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
|
|
1172
|
-
)
|
|
1173
|
-
outputs.hidden_states = tuple(padded_hidden_states)
|
|
1174
|
-
|
|
1175
|
-
if not return_dict:
|
|
1176
|
-
output = (logits,)
|
|
1177
|
-
return ((loss,) + output) if loss is not None else output
|
|
1178
|
-
|
|
1179
585
|
return MaskedLMOutput(
|
|
1180
586
|
loss=loss,
|
|
1181
587
|
logits=logits,
|
|
@@ -1203,81 +609,39 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
         r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
-        if input_ids is not None:
-            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
-
-        if batch_size is None and seq_len is None:
-            if inputs_embeds is not None:
-                batch_size, seq_len = inputs_embeds.shape[:2]
-            else:
-                batch_size, seq_len = input_ids.shape[:2]
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]
 
         if self.config.classifier_pooling == "cls":
             last_hidden_state = last_hidden_state[:, 0]
         elif self.config.classifier_pooling == "mean":
+            if attention_mask is None:
+                attention_mask = torch.ones(
+                    last_hidden_state.shape[:2], device=last_hidden_state.device, dtype=torch.bool
+                )
             last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
                 dim=1, keepdim=True
             )
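A minimal sketch of how a caller interacts with the refactored head: the explicit unpadding arguments and output flags are gone from the signature, and anything extra travels through `**kwargs: Unpack[TransformersKwargs]` down to `self.model(...)`. This is only an illustration, not part of the diff: the checkpoint name is a placeholder for any ModernBERT checkpoint, and it assumes flags such as `output_hidden_states` are accepted `TransformersKwargs`, as elsewhere in the library.

import torch
from transformers import AutoTokenizer, ModernBertForSequenceClassification

repo = "answerdotai/ModernBERT-base"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(repo)
model = ModernBertForSequenceClassification.from_pretrained(
    repo, num_labels=2, classifier_pooling="mean"
)
model.eval()

inputs = tokenizer("ModernBERT heads now forward extra flags through kwargs.", return_tensors="pt")

with torch.no_grad():
    # output_hidden_states is no longer an explicit parameter of forward();
    # it is forwarded through **kwargs to the backbone.
    outputs = model(**inputs, output_hidden_states=True)

print(outputs.logits.shape)        # (1, 2)
print(len(outputs.hidden_states))  # embeddings + one tensor per encoder layer

Note the design change visible in the hunk above: with `classifier_pooling="mean"` and no `attention_mask` supplied, the head now builds an all-ones mask only for the pooling step, instead of materializing one for the whole forward pass.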
@@ -1309,10 +673,6 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
 
-        if not return_dict:
-            output = (logits,)
-            return ((loss,) + output) if loss is not None else output
-
         return SequenceClassifierOutput(
             loss=loss,
             logits=logits,
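For reference, a sketch of what the deleted `if not return_dict:` branch delegates to: the `@can_return_tuple` decorator added above is expected to pop `return_dict` from the kwargs and call `.to_tuple()` on the returned `SequenceClassifierOutput`, so the hand-written tuple fallback is no longer needed. This assumes the decorator's usual behavior; `model` and `inputs` are the objects from the previous sketch.

with torch.no_grad():
    out = model(**inputs)                     # SequenceClassifierOutput (default)
    tup = model(**inputs, return_dict=False)  # plain tuple via @can_return_tuple

# No labels were passed, so loss is None and is dropped from the tuple;
# the first element is therefore the logits tensor.
assert isinstance(tup, tuple)
assert tup[0].shape == out.logits.shape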
@@ -1339,60 +699,27 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | TokenClassifierOutput:
         r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]
 
@@ -1405,10 +732,6 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return TokenClassifierOutput(
             loss=loss,
             logits=logits,
@@ -1430,57 +753,22 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
 
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
-        input_ids: torch.Tensor | None,
+        input_ids: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         start_positions: torch.Tensor | None = None,
         end_positions: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
-        r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]
 
@@ -1496,10 +784,6 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         if start_positions is not None and end_positions is not None:
             loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
 
-        if not return_dict:
-            output = (start_logits, end_logits) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return QuestionAnsweringModelOutput(
             loss=loss,
             start_logits=start_logits,
@@ -1527,45 +811,22 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
        input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
         r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
             num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
 
         input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -1577,22 +838,12 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
             else None
         )
 
-        self._maybe_set_compile()
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]  # shape (num_choices, seq_len, hidden_size)
 
@@ -1624,10 +875,6 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
 
-        if not return_dict:
-            output = (reshaped_logits,) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return MultipleChoiceModelOutput(
             loss=loss,
             logits=reshaped_logits,