transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- transformers/__init__.py +11 -37
- transformers/activations.py +2 -2
- transformers/audio_utils.py +32 -32
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +26 -126
- transformers/cli/chat.py +3 -3
- transformers/cli/serve.py +13 -10
- transformers/cli/transformers.py +2 -1
- transformers/configuration_utils.py +22 -92
- transformers/conversion_mapping.py +150 -26
- transformers/convert_slow_tokenizer.py +9 -12
- transformers/core_model_loading.py +217 -129
- transformers/data/processors/glue.py +0 -1
- transformers/data/processors/utils.py +0 -1
- transformers/data/processors/xnli.py +0 -1
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +10 -11
- transformers/distributed/configuration_utils.py +1 -2
- transformers/dynamic_module_utils.py +23 -23
- transformers/feature_extraction_sequence_utils.py +19 -23
- transformers/feature_extraction_utils.py +14 -14
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +2 -4
- transformers/generation/configuration_utils.py +54 -39
- transformers/generation/continuous_batching/__init__.py +0 -1
- transformers/generation/continuous_batching/cache.py +74 -44
- transformers/generation/continuous_batching/cache_manager.py +28 -28
- transformers/generation/continuous_batching/continuous_api.py +133 -414
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +77 -19
- transformers/generation/continuous_batching/scheduler.py +154 -104
- transformers/generation/logits_process.py +10 -133
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/streamers.py +0 -1
- transformers/generation/utils.py +91 -121
- transformers/generation/watermarking.py +2 -3
- transformers/hf_argparser.py +9 -13
- transformers/hyperparameter_search.py +1 -2
- transformers/image_processing_base.py +9 -9
- transformers/image_processing_utils.py +11 -15
- transformers/image_processing_utils_fast.py +70 -71
- transformers/image_transforms.py +73 -42
- transformers/image_utils.py +30 -37
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/awq.py +1 -3
- transformers/integrations/deepspeed.py +146 -4
- transformers/integrations/eetq.py +0 -1
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/fbgemm_fp8.py +1 -2
- transformers/integrations/finegrained_fp8.py +149 -13
- transformers/integrations/flash_attention.py +3 -8
- transformers/integrations/flex_attention.py +1 -1
- transformers/integrations/fp_quant.py +4 -6
- transformers/integrations/ggml.py +0 -1
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/integration_utils.py +2 -3
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +52 -40
- transformers/integrations/peft.py +488 -176
- transformers/integrations/quark.py +2 -4
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/integrations/torchao.py +4 -6
- transformers/loss/loss_lw_detr.py +356 -0
- transformers/loss/loss_utils.py +2 -0
- transformers/masking_utils.py +199 -59
- transformers/model_debugging_utils.py +4 -5
- transformers/modelcard.py +14 -192
- transformers/modeling_attn_mask_utils.py +19 -19
- transformers/modeling_flash_attention_utils.py +28 -29
- transformers/modeling_gguf_pytorch_utils.py +5 -5
- transformers/modeling_layers.py +21 -22
- transformers/modeling_outputs.py +242 -253
- transformers/modeling_rope_utils.py +32 -32
- transformers/modeling_utils.py +416 -438
- transformers/models/__init__.py +10 -0
- transformers/models/afmoe/configuration_afmoe.py +40 -33
- transformers/models/afmoe/modeling_afmoe.py +38 -41
- transformers/models/afmoe/modular_afmoe.py +23 -25
- transformers/models/aimv2/configuration_aimv2.py +2 -10
- transformers/models/aimv2/modeling_aimv2.py +46 -45
- transformers/models/aimv2/modular_aimv2.py +13 -19
- transformers/models/albert/configuration_albert.py +8 -2
- transformers/models/albert/modeling_albert.py +70 -72
- transformers/models/albert/tokenization_albert.py +1 -4
- transformers/models/align/configuration_align.py +8 -6
- transformers/models/align/modeling_align.py +83 -86
- transformers/models/align/processing_align.py +2 -30
- transformers/models/altclip/configuration_altclip.py +4 -7
- transformers/models/altclip/modeling_altclip.py +106 -103
- transformers/models/altclip/processing_altclip.py +2 -15
- transformers/models/apertus/__init__.py +0 -1
- transformers/models/apertus/configuration_apertus.py +23 -28
- transformers/models/apertus/modeling_apertus.py +35 -38
- transformers/models/apertus/modular_apertus.py +36 -40
- transformers/models/arcee/configuration_arcee.py +25 -30
- transformers/models/arcee/modeling_arcee.py +35 -38
- transformers/models/arcee/modular_arcee.py +20 -23
- transformers/models/aria/configuration_aria.py +31 -44
- transformers/models/aria/image_processing_aria.py +25 -27
- transformers/models/aria/modeling_aria.py +102 -102
- transformers/models/aria/modular_aria.py +111 -124
- transformers/models/aria/processing_aria.py +28 -35
- transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
- transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
- transformers/models/audioflamingo3/__init__.py +0 -1
- transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
- transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
- transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
- transformers/models/auto/auto_factory.py +12 -11
- transformers/models/auto/configuration_auto.py +48 -5
- transformers/models/auto/feature_extraction_auto.py +5 -7
- transformers/models/auto/image_processing_auto.py +30 -39
- transformers/models/auto/modeling_auto.py +33 -199
- transformers/models/auto/processing_auto.py +11 -19
- transformers/models/auto/tokenization_auto.py +38 -37
- transformers/models/auto/video_processing_auto.py +7 -8
- transformers/models/autoformer/configuration_autoformer.py +4 -7
- transformers/models/autoformer/modeling_autoformer.py +100 -101
- transformers/models/aya_vision/configuration_aya_vision.py +4 -1
- transformers/models/aya_vision/modeling_aya_vision.py +64 -99
- transformers/models/aya_vision/modular_aya_vision.py +46 -74
- transformers/models/aya_vision/processing_aya_vision.py +25 -53
- transformers/models/bamba/configuration_bamba.py +46 -39
- transformers/models/bamba/modeling_bamba.py +83 -119
- transformers/models/bamba/modular_bamba.py +70 -109
- transformers/models/bark/configuration_bark.py +6 -8
- transformers/models/bark/generation_configuration_bark.py +3 -5
- transformers/models/bark/modeling_bark.py +64 -65
- transformers/models/bark/processing_bark.py +19 -41
- transformers/models/bart/configuration_bart.py +9 -5
- transformers/models/bart/modeling_bart.py +124 -129
- transformers/models/barthez/tokenization_barthez.py +1 -4
- transformers/models/bartpho/tokenization_bartpho.py +6 -7
- transformers/models/beit/configuration_beit.py +2 -15
- transformers/models/beit/image_processing_beit.py +53 -56
- transformers/models/beit/image_processing_beit_fast.py +11 -12
- transformers/models/beit/modeling_beit.py +65 -62
- transformers/models/bert/configuration_bert.py +12 -2
- transformers/models/bert/modeling_bert.py +117 -152
- transformers/models/bert/tokenization_bert.py +2 -4
- transformers/models/bert/tokenization_bert_legacy.py +3 -5
- transformers/models/bert_generation/configuration_bert_generation.py +17 -2
- transformers/models/bert_generation/modeling_bert_generation.py +53 -55
- transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
- transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
- transformers/models/bertweet/tokenization_bertweet.py +1 -3
- transformers/models/big_bird/configuration_big_bird.py +12 -9
- transformers/models/big_bird/modeling_big_bird.py +107 -124
- transformers/models/big_bird/tokenization_big_bird.py +1 -4
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
- transformers/models/biogpt/configuration_biogpt.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +73 -79
- transformers/models/biogpt/modular_biogpt.py +60 -66
- transformers/models/biogpt/tokenization_biogpt.py +3 -5
- transformers/models/bit/configuration_bit.py +2 -5
- transformers/models/bit/image_processing_bit.py +21 -24
- transformers/models/bit/image_processing_bit_fast.py +0 -1
- transformers/models/bit/modeling_bit.py +15 -16
- transformers/models/bitnet/configuration_bitnet.py +23 -28
- transformers/models/bitnet/modeling_bitnet.py +34 -38
- transformers/models/bitnet/modular_bitnet.py +7 -10
- transformers/models/blenderbot/configuration_blenderbot.py +8 -5
- transformers/models/blenderbot/modeling_blenderbot.py +68 -99
- transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
- transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
- transformers/models/blip/configuration_blip.py +9 -10
- transformers/models/blip/image_processing_blip.py +17 -20
- transformers/models/blip/image_processing_blip_fast.py +0 -1
- transformers/models/blip/modeling_blip.py +115 -108
- transformers/models/blip/modeling_blip_text.py +63 -65
- transformers/models/blip/processing_blip.py +5 -36
- transformers/models/blip_2/configuration_blip_2.py +2 -2
- transformers/models/blip_2/modeling_blip_2.py +145 -121
- transformers/models/blip_2/processing_blip_2.py +8 -38
- transformers/models/bloom/configuration_bloom.py +5 -2
- transformers/models/bloom/modeling_bloom.py +60 -60
- transformers/models/blt/configuration_blt.py +94 -86
- transformers/models/blt/modeling_blt.py +93 -90
- transformers/models/blt/modular_blt.py +127 -69
- transformers/models/bridgetower/configuration_bridgetower.py +7 -2
- transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
- transformers/models/bridgetower/modeling_bridgetower.py +136 -124
- transformers/models/bridgetower/processing_bridgetower.py +2 -16
- transformers/models/bros/configuration_bros.py +24 -18
- transformers/models/bros/modeling_bros.py +78 -80
- transformers/models/bros/processing_bros.py +2 -12
- transformers/models/byt5/tokenization_byt5.py +4 -6
- transformers/models/camembert/configuration_camembert.py +8 -2
- transformers/models/camembert/modeling_camembert.py +97 -99
- transformers/models/camembert/modular_camembert.py +51 -54
- transformers/models/camembert/tokenization_camembert.py +1 -4
- transformers/models/canine/configuration_canine.py +4 -2
- transformers/models/canine/modeling_canine.py +73 -75
- transformers/models/canine/tokenization_canine.py +0 -1
- transformers/models/chameleon/configuration_chameleon.py +29 -34
- transformers/models/chameleon/image_processing_chameleon.py +21 -24
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
- transformers/models/chameleon/modeling_chameleon.py +135 -92
- transformers/models/chameleon/processing_chameleon.py +16 -41
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
- transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
- transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
- transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
- transformers/models/clap/configuration_clap.py +4 -9
- transformers/models/clap/feature_extraction_clap.py +9 -10
- transformers/models/clap/modeling_clap.py +109 -111
- transformers/models/clap/processing_clap.py +2 -15
- transformers/models/clip/configuration_clip.py +4 -2
- transformers/models/clip/image_processing_clip.py +21 -24
- transformers/models/clip/image_processing_clip_fast.py +9 -1
- transformers/models/clip/modeling_clip.py +70 -68
- transformers/models/clip/processing_clip.py +2 -14
- transformers/models/clip/tokenization_clip.py +2 -5
- transformers/models/clipseg/configuration_clipseg.py +4 -2
- transformers/models/clipseg/modeling_clipseg.py +113 -112
- transformers/models/clipseg/processing_clipseg.py +19 -42
- transformers/models/clvp/configuration_clvp.py +15 -5
- transformers/models/clvp/feature_extraction_clvp.py +7 -10
- transformers/models/clvp/modeling_clvp.py +138 -145
- transformers/models/clvp/number_normalizer.py +1 -2
- transformers/models/clvp/processing_clvp.py +3 -20
- transformers/models/clvp/tokenization_clvp.py +0 -1
- transformers/models/code_llama/tokenization_code_llama.py +3 -6
- transformers/models/codegen/configuration_codegen.py +4 -4
- transformers/models/codegen/modeling_codegen.py +50 -49
- transformers/models/codegen/tokenization_codegen.py +5 -6
- transformers/models/cohere/configuration_cohere.py +25 -30
- transformers/models/cohere/modeling_cohere.py +39 -42
- transformers/models/cohere/modular_cohere.py +27 -31
- transformers/models/cohere/tokenization_cohere.py +5 -6
- transformers/models/cohere2/configuration_cohere2.py +27 -32
- transformers/models/cohere2/modeling_cohere2.py +38 -41
- transformers/models/cohere2/modular_cohere2.py +48 -52
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
- transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
- transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
- transformers/models/colpali/configuration_colpali.py +0 -1
- transformers/models/colpali/modeling_colpali.py +14 -16
- transformers/models/colpali/modular_colpali.py +11 -51
- transformers/models/colpali/processing_colpali.py +14 -52
- transformers/models/colqwen2/modeling_colqwen2.py +27 -28
- transformers/models/colqwen2/modular_colqwen2.py +36 -74
- transformers/models/colqwen2/processing_colqwen2.py +16 -52
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
- transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
- transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
- transformers/models/convbert/configuration_convbert.py +11 -8
- transformers/models/convbert/modeling_convbert.py +85 -87
- transformers/models/convbert/tokenization_convbert.py +0 -1
- transformers/models/convnext/configuration_convnext.py +2 -5
- transformers/models/convnext/image_processing_convnext.py +18 -21
- transformers/models/convnext/image_processing_convnext_fast.py +7 -8
- transformers/models/convnext/modeling_convnext.py +12 -14
- transformers/models/convnextv2/configuration_convnextv2.py +2 -5
- transformers/models/convnextv2/modeling_convnextv2.py +12 -14
- transformers/models/cpm/tokenization_cpm.py +6 -7
- transformers/models/cpm/tokenization_cpm_fast.py +3 -5
- transformers/models/cpmant/configuration_cpmant.py +4 -1
- transformers/models/cpmant/modeling_cpmant.py +38 -40
- transformers/models/cpmant/tokenization_cpmant.py +1 -3
- transformers/models/csm/configuration_csm.py +58 -66
- transformers/models/csm/generation_csm.py +13 -14
- transformers/models/csm/modeling_csm.py +81 -84
- transformers/models/csm/modular_csm.py +56 -58
- transformers/models/csm/processing_csm.py +25 -68
- transformers/models/ctrl/configuration_ctrl.py +16 -1
- transformers/models/ctrl/modeling_ctrl.py +51 -66
- transformers/models/ctrl/tokenization_ctrl.py +0 -1
- transformers/models/cvt/configuration_cvt.py +0 -1
- transformers/models/cvt/modeling_cvt.py +13 -15
- transformers/models/cwm/__init__.py +0 -1
- transformers/models/cwm/configuration_cwm.py +8 -12
- transformers/models/cwm/modeling_cwm.py +36 -38
- transformers/models/cwm/modular_cwm.py +10 -12
- transformers/models/d_fine/configuration_d_fine.py +10 -57
- transformers/models/d_fine/modeling_d_fine.py +786 -927
- transformers/models/d_fine/modular_d_fine.py +339 -417
- transformers/models/dab_detr/configuration_dab_detr.py +22 -49
- transformers/models/dab_detr/modeling_dab_detr.py +79 -77
- transformers/models/dac/configuration_dac.py +0 -1
- transformers/models/dac/feature_extraction_dac.py +6 -9
- transformers/models/dac/modeling_dac.py +22 -24
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
- transformers/models/data2vec/configuration_data2vec_text.py +11 -3
- transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
- transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
- transformers/models/data2vec/modeling_data2vec_text.py +97 -99
- transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
- transformers/models/data2vec/modular_data2vec_audio.py +6 -1
- transformers/models/data2vec/modular_data2vec_text.py +51 -54
- transformers/models/dbrx/configuration_dbrx.py +29 -22
- transformers/models/dbrx/modeling_dbrx.py +45 -48
- transformers/models/dbrx/modular_dbrx.py +37 -39
- transformers/models/deberta/configuration_deberta.py +6 -1
- transformers/models/deberta/modeling_deberta.py +57 -60
- transformers/models/deberta/tokenization_deberta.py +2 -5
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
- transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
- transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
- transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
- transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
- transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
- transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
- transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
- transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
- transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
- transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
- transformers/models/deit/configuration_deit.py +0 -1
- transformers/models/deit/image_processing_deit.py +18 -21
- transformers/models/deit/image_processing_deit_fast.py +0 -1
- transformers/models/deit/modeling_deit.py +27 -25
- transformers/models/depth_anything/configuration_depth_anything.py +12 -43
- transformers/models/depth_anything/modeling_depth_anything.py +10 -11
- transformers/models/depth_pro/configuration_depth_pro.py +0 -1
- transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
- transformers/models/depth_pro/modeling_depth_pro.py +29 -27
- transformers/models/detr/configuration_detr.py +18 -50
- transformers/models/detr/image_processing_detr.py +64 -66
- transformers/models/detr/image_processing_detr_fast.py +33 -34
- transformers/models/detr/modeling_detr.py +748 -789
- transformers/models/dia/configuration_dia.py +9 -15
- transformers/models/dia/feature_extraction_dia.py +6 -9
- transformers/models/dia/generation_dia.py +48 -53
- transformers/models/dia/modeling_dia.py +68 -71
- transformers/models/dia/modular_dia.py +56 -58
- transformers/models/dia/processing_dia.py +39 -29
- transformers/models/dia/tokenization_dia.py +3 -6
- transformers/models/diffllama/configuration_diffllama.py +25 -30
- transformers/models/diffllama/modeling_diffllama.py +45 -53
- transformers/models/diffllama/modular_diffllama.py +18 -25
- transformers/models/dinat/configuration_dinat.py +2 -5
- transformers/models/dinat/modeling_dinat.py +47 -48
- transformers/models/dinov2/configuration_dinov2.py +2 -5
- transformers/models/dinov2/modeling_dinov2.py +20 -21
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
- transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
- transformers/models/distilbert/configuration_distilbert.py +8 -2
- transformers/models/distilbert/modeling_distilbert.py +47 -49
- transformers/models/distilbert/tokenization_distilbert.py +0 -1
- transformers/models/doge/__init__.py +0 -1
- transformers/models/doge/configuration_doge.py +42 -35
- transformers/models/doge/modeling_doge.py +46 -49
- transformers/models/doge/modular_doge.py +77 -68
- transformers/models/donut/configuration_donut_swin.py +0 -1
- transformers/models/donut/image_processing_donut.py +26 -29
- transformers/models/donut/image_processing_donut_fast.py +9 -14
- transformers/models/donut/modeling_donut_swin.py +44 -46
- transformers/models/donut/processing_donut.py +5 -26
- transformers/models/dots1/configuration_dots1.py +43 -36
- transformers/models/dots1/modeling_dots1.py +35 -38
- transformers/models/dots1/modular_dots1.py +0 -1
- transformers/models/dpr/configuration_dpr.py +19 -2
- transformers/models/dpr/modeling_dpr.py +37 -39
- transformers/models/dpr/tokenization_dpr.py +7 -9
- transformers/models/dpr/tokenization_dpr_fast.py +7 -9
- transformers/models/dpt/configuration_dpt.py +23 -66
- transformers/models/dpt/image_processing_dpt.py +65 -66
- transformers/models/dpt/image_processing_dpt_fast.py +18 -19
- transformers/models/dpt/modeling_dpt.py +38 -36
- transformers/models/dpt/modular_dpt.py +14 -15
- transformers/models/edgetam/configuration_edgetam.py +1 -2
- transformers/models/edgetam/modeling_edgetam.py +87 -89
- transformers/models/edgetam/modular_edgetam.py +7 -13
- transformers/models/edgetam_video/__init__.py +0 -1
- transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
- transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
- transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
- transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
- transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
- transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
- transformers/models/efficientnet/configuration_efficientnet.py +0 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
- transformers/models/efficientnet/modeling_efficientnet.py +12 -14
- transformers/models/electra/configuration_electra.py +13 -3
- transformers/models/electra/modeling_electra.py +107 -109
- transformers/models/emu3/configuration_emu3.py +17 -17
- transformers/models/emu3/image_processing_emu3.py +44 -39
- transformers/models/emu3/modeling_emu3.py +143 -109
- transformers/models/emu3/modular_emu3.py +109 -73
- transformers/models/emu3/processing_emu3.py +18 -43
- transformers/models/encodec/configuration_encodec.py +2 -4
- transformers/models/encodec/feature_extraction_encodec.py +10 -13
- transformers/models/encodec/modeling_encodec.py +25 -29
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
- transformers/models/eomt/configuration_eomt.py +12 -14
- transformers/models/eomt/image_processing_eomt.py +53 -55
- transformers/models/eomt/image_processing_eomt_fast.py +18 -19
- transformers/models/eomt/modeling_eomt.py +19 -21
- transformers/models/eomt/modular_eomt.py +28 -30
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -3
- transformers/models/ernie/modeling_ernie.py +127 -162
- transformers/models/ernie/modular_ernie.py +91 -103
- transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
- transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
- transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
- transformers/models/esm/configuration_esm.py +11 -15
- transformers/models/esm/modeling_esm.py +35 -37
- transformers/models/esm/modeling_esmfold.py +43 -50
- transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
- transformers/models/esm/openfold_utils/loss.py +1 -2
- transformers/models/esm/openfold_utils/protein.py +15 -16
- transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
- transformers/models/esm/tokenization_esm.py +2 -4
- transformers/models/evolla/configuration_evolla.py +50 -40
- transformers/models/evolla/modeling_evolla.py +69 -68
- transformers/models/evolla/modular_evolla.py +50 -48
- transformers/models/evolla/processing_evolla.py +23 -35
- transformers/models/exaone4/configuration_exaone4.py +27 -27
- transformers/models/exaone4/modeling_exaone4.py +36 -39
- transformers/models/exaone4/modular_exaone4.py +51 -50
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +31 -26
- transformers/models/falcon/modeling_falcon.py +76 -84
- transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
- transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
- transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
- transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
- transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
- transformers/models/flaubert/configuration_flaubert.py +10 -5
- transformers/models/flaubert/modeling_flaubert.py +125 -129
- transformers/models/flaubert/tokenization_flaubert.py +3 -5
- transformers/models/flava/configuration_flava.py +9 -9
- transformers/models/flava/image_processing_flava.py +66 -67
- transformers/models/flava/image_processing_flava_fast.py +46 -47
- transformers/models/flava/modeling_flava.py +144 -135
- transformers/models/flava/processing_flava.py +2 -12
- transformers/models/flex_olmo/__init__.py +0 -1
- transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
- transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
- transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
- transformers/models/florence2/configuration_florence2.py +4 -1
- transformers/models/florence2/modeling_florence2.py +96 -72
- transformers/models/florence2/modular_florence2.py +100 -107
- transformers/models/florence2/processing_florence2.py +18 -47
- transformers/models/fnet/configuration_fnet.py +6 -2
- transformers/models/fnet/modeling_fnet.py +69 -80
- transformers/models/fnet/tokenization_fnet.py +0 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -5
- transformers/models/focalnet/modeling_focalnet.py +49 -48
- transformers/models/fsmt/configuration_fsmt.py +12 -17
- transformers/models/fsmt/modeling_fsmt.py +47 -48
- transformers/models/fsmt/tokenization_fsmt.py +3 -5
- transformers/models/funnel/configuration_funnel.py +8 -1
- transformers/models/funnel/modeling_funnel.py +91 -93
- transformers/models/funnel/tokenization_funnel.py +2 -5
- transformers/models/fuyu/configuration_fuyu.py +28 -34
- transformers/models/fuyu/image_processing_fuyu.py +29 -31
- transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
- transformers/models/fuyu/modeling_fuyu.py +50 -52
- transformers/models/fuyu/processing_fuyu.py +9 -36
- transformers/models/gemma/configuration_gemma.py +25 -30
- transformers/models/gemma/modeling_gemma.py +36 -38
- transformers/models/gemma/modular_gemma.py +33 -36
- transformers/models/gemma/tokenization_gemma.py +3 -6
- transformers/models/gemma2/configuration_gemma2.py +30 -35
- transformers/models/gemma2/modeling_gemma2.py +38 -41
- transformers/models/gemma2/modular_gemma2.py +63 -67
- transformers/models/gemma3/configuration_gemma3.py +53 -48
- transformers/models/gemma3/image_processing_gemma3.py +29 -31
- transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
- transformers/models/gemma3/modeling_gemma3.py +123 -122
- transformers/models/gemma3/modular_gemma3.py +128 -125
- transformers/models/gemma3/processing_gemma3.py +5 -5
- transformers/models/gemma3n/configuration_gemma3n.py +42 -30
- transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
- transformers/models/gemma3n/modeling_gemma3n.py +166 -147
- transformers/models/gemma3n/modular_gemma3n.py +176 -148
- transformers/models/gemma3n/processing_gemma3n.py +12 -26
- transformers/models/git/configuration_git.py +5 -8
- transformers/models/git/modeling_git.py +115 -127
- transformers/models/git/processing_git.py +2 -14
- transformers/models/glm/configuration_glm.py +26 -30
- transformers/models/glm/modeling_glm.py +36 -39
- transformers/models/glm/modular_glm.py +4 -7
- transformers/models/glm4/configuration_glm4.py +26 -30
- transformers/models/glm4/modeling_glm4.py +39 -41
- transformers/models/glm4/modular_glm4.py +8 -10
- transformers/models/glm46v/configuration_glm46v.py +4 -1
- transformers/models/glm46v/image_processing_glm46v.py +40 -38
- transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
- transformers/models/glm46v/modeling_glm46v.py +138 -93
- transformers/models/glm46v/modular_glm46v.py +5 -3
- transformers/models/glm46v/processing_glm46v.py +7 -41
- transformers/models/glm46v/video_processing_glm46v.py +9 -11
- transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
- transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
- transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
- transformers/models/glm4_moe_lite/__init__.py +28 -0
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
- transformers/models/glm4v/configuration_glm4v.py +25 -24
- transformers/models/glm4v/image_processing_glm4v.py +39 -38
- transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
- transformers/models/glm4v/modeling_glm4v.py +249 -210
- transformers/models/glm4v/modular_glm4v.py +211 -230
- transformers/models/glm4v/processing_glm4v.py +7 -41
- transformers/models/glm4v/video_processing_glm4v.py +9 -11
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
- transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
- transformers/models/glm_image/__init__.py +31 -0
- transformers/models/glm_image/configuration_glm_image.py +358 -0
- transformers/models/glm_image/image_processing_glm_image.py +503 -0
- transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
- transformers/models/glm_image/modeling_glm_image.py +1691 -0
- transformers/models/glm_image/modular_glm_image.py +1640 -0
- transformers/models/glm_image/processing_glm_image.py +265 -0
- transformers/models/glm_ocr/__init__.py +28 -0
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/__init__.py +0 -1
- transformers/models/glmasr/configuration_glmasr.py +0 -1
- transformers/models/glmasr/modeling_glmasr.py +51 -46
- transformers/models/glmasr/modular_glmasr.py +39 -29
- transformers/models/glmasr/processing_glmasr.py +7 -8
- transformers/models/glpn/configuration_glpn.py +0 -1
- transformers/models/glpn/image_processing_glpn.py +11 -12
- transformers/models/glpn/image_processing_glpn_fast.py +11 -12
- transformers/models/glpn/modeling_glpn.py +14 -14
- transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
- transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
- transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
- transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
- transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
- transformers/models/gpt2/configuration_gpt2.py +13 -2
- transformers/models/gpt2/modeling_gpt2.py +111 -113
- transformers/models/gpt2/tokenization_gpt2.py +6 -9
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
- transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
- transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
- transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
- transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
- transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
- transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
- transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
- transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
- transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
- transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
- transformers/models/gptj/configuration_gptj.py +4 -5
- transformers/models/gptj/modeling_gptj.py +85 -88
- transformers/models/granite/configuration_granite.py +28 -33
- transformers/models/granite/modeling_granite.py +43 -45
- transformers/models/granite/modular_granite.py +29 -31
- transformers/models/granite_speech/configuration_granite_speech.py +0 -1
- transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
- transformers/models/granite_speech/modeling_granite_speech.py +84 -60
- transformers/models/granite_speech/processing_granite_speech.py +11 -4
- transformers/models/granitemoe/configuration_granitemoe.py +31 -36
- transformers/models/granitemoe/modeling_granitemoe.py +39 -41
- transformers/models/granitemoe/modular_granitemoe.py +21 -23
- transformers/models/granitemoehybrid/__init__.py +0 -1
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
- transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
- transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
- transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
- transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
- transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
- transformers/models/groupvit/configuration_groupvit.py +4 -2
- transformers/models/groupvit/modeling_groupvit.py +98 -92
- transformers/models/helium/configuration_helium.py +25 -29
- transformers/models/helium/modeling_helium.py +37 -40
- transformers/models/helium/modular_helium.py +3 -7
- transformers/models/herbert/tokenization_herbert.py +4 -6
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
- transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
- transformers/models/hiera/configuration_hiera.py +2 -5
- transformers/models/hiera/modeling_hiera.py +71 -70
- transformers/models/hubert/configuration_hubert.py +4 -2
- transformers/models/hubert/modeling_hubert.py +42 -41
- transformers/models/hubert/modular_hubert.py +8 -11
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
- transformers/models/ibert/configuration_ibert.py +4 -2
- transformers/models/ibert/modeling_ibert.py +60 -62
- transformers/models/ibert/quant_modules.py +0 -1
- transformers/models/idefics/configuration_idefics.py +5 -8
- transformers/models/idefics/image_processing_idefics.py +13 -15
- transformers/models/idefics/modeling_idefics.py +63 -65
- transformers/models/idefics/perceiver.py +1 -3
- transformers/models/idefics/processing_idefics.py +32 -48
- transformers/models/idefics/vision.py +27 -28
- transformers/models/idefics2/configuration_idefics2.py +1 -3
- transformers/models/idefics2/image_processing_idefics2.py +31 -32
- transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
- transformers/models/idefics2/modeling_idefics2.py +126 -106
- transformers/models/idefics2/processing_idefics2.py +10 -68
- transformers/models/idefics3/configuration_idefics3.py +1 -4
- transformers/models/idefics3/image_processing_idefics3.py +42 -43
- transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
- transformers/models/idefics3/modeling_idefics3.py +113 -92
- transformers/models/idefics3/processing_idefics3.py +15 -69
- transformers/models/ijepa/configuration_ijepa.py +0 -1
- transformers/models/ijepa/modeling_ijepa.py +13 -14
- transformers/models/ijepa/modular_ijepa.py +5 -7
- transformers/models/imagegpt/configuration_imagegpt.py +9 -2
- transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
- transformers/models/imagegpt/modeling_imagegpt.py +65 -62
- transformers/models/informer/configuration_informer.py +6 -9
- transformers/models/informer/modeling_informer.py +87 -89
- transformers/models/informer/modular_informer.py +13 -16
- transformers/models/instructblip/configuration_instructblip.py +2 -2
- transformers/models/instructblip/modeling_instructblip.py +104 -79
- transformers/models/instructblip/processing_instructblip.py +10 -36
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
- transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
- transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
- transformers/models/internvl/configuration_internvl.py +5 -1
- transformers/models/internvl/modeling_internvl.py +76 -98
- transformers/models/internvl/modular_internvl.py +45 -59
- transformers/models/internvl/processing_internvl.py +12 -45
- transformers/models/internvl/video_processing_internvl.py +10 -11
- transformers/models/jais2/configuration_jais2.py +25 -29
- transformers/models/jais2/modeling_jais2.py +36 -38
- transformers/models/jais2/modular_jais2.py +20 -22
- transformers/models/jamba/configuration_jamba.py +5 -8
- transformers/models/jamba/modeling_jamba.py +47 -50
- transformers/models/jamba/modular_jamba.py +40 -41
- transformers/models/janus/configuration_janus.py +0 -1
- transformers/models/janus/image_processing_janus.py +37 -39
- transformers/models/janus/image_processing_janus_fast.py +20 -21
- transformers/models/janus/modeling_janus.py +103 -188
- transformers/models/janus/modular_janus.py +122 -83
- transformers/models/janus/processing_janus.py +17 -43
- transformers/models/jetmoe/configuration_jetmoe.py +26 -27
- transformers/models/jetmoe/modeling_jetmoe.py +42 -45
- transformers/models/jetmoe/modular_jetmoe.py +33 -36
- transformers/models/kosmos2/configuration_kosmos2.py +10 -9
- transformers/models/kosmos2/modeling_kosmos2.py +199 -178
- transformers/models/kosmos2/processing_kosmos2.py +40 -55
- transformers/models/kosmos2_5/__init__.py +0 -1
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
- transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
- transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
- transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
- transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
- transformers/models/lasr/configuration_lasr.py +3 -7
- transformers/models/lasr/feature_extraction_lasr.py +10 -12
- transformers/models/lasr/modeling_lasr.py +21 -24
- transformers/models/lasr/modular_lasr.py +11 -13
- transformers/models/lasr/processing_lasr.py +12 -6
- transformers/models/lasr/tokenization_lasr.py +2 -4
- transformers/models/layoutlm/configuration_layoutlm.py +14 -2
- transformers/models/layoutlm/modeling_layoutlm.py +70 -72
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
- transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
- transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
- transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
- transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
- transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
- transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
- transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
- transformers/models/led/configuration_led.py +8 -12
- transformers/models/led/modeling_led.py +113 -267
- transformers/models/levit/configuration_levit.py +0 -1
- transformers/models/levit/image_processing_levit.py +19 -21
- transformers/models/levit/image_processing_levit_fast.py +4 -5
- transformers/models/levit/modeling_levit.py +17 -19
- transformers/models/lfm2/configuration_lfm2.py +27 -30
- transformers/models/lfm2/modeling_lfm2.py +46 -48
- transformers/models/lfm2/modular_lfm2.py +32 -32
- transformers/models/lfm2_moe/__init__.py +0 -1
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
- transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
- transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
- transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
- transformers/models/lightglue/image_processing_lightglue.py +16 -15
- transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
- transformers/models/lightglue/modeling_lightglue.py +31 -33
- transformers/models/lightglue/modular_lightglue.py +31 -31
- transformers/models/lighton_ocr/__init__.py +28 -0
- transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
- transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
- transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
- transformers/models/lilt/configuration_lilt.py +6 -2
- transformers/models/lilt/modeling_lilt.py +53 -55
- transformers/models/llama/configuration_llama.py +26 -31
- transformers/models/llama/modeling_llama.py +35 -38
- transformers/models/llama/tokenization_llama.py +2 -4
- transformers/models/llama4/configuration_llama4.py +87 -69
- transformers/models/llama4/image_processing_llama4_fast.py +11 -12
- transformers/models/llama4/modeling_llama4.py +116 -115
- transformers/models/llama4/processing_llama4.py +33 -57
- transformers/models/llava/configuration_llava.py +10 -1
- transformers/models/llava/image_processing_llava.py +25 -28
- transformers/models/llava/image_processing_llava_fast.py +9 -10
- transformers/models/llava/modeling_llava.py +73 -102
- transformers/models/llava/processing_llava.py +18 -51
- transformers/models/llava_next/configuration_llava_next.py +2 -2
- transformers/models/llava_next/image_processing_llava_next.py +43 -45
- transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
- transformers/models/llava_next/modeling_llava_next.py +103 -104
- transformers/models/llava_next/processing_llava_next.py +18 -47
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
- transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
- transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
- transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
- transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
- transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
- transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
- transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
- transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
- transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
- transformers/models/longcat_flash/__init__.py +0 -1
- transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
- transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
- transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
- transformers/models/longformer/configuration_longformer.py +5 -5
- transformers/models/longformer/modeling_longformer.py +99 -101
- transformers/models/longt5/configuration_longt5.py +9 -7
- transformers/models/longt5/modeling_longt5.py +45 -45
- transformers/models/luke/configuration_luke.py +8 -2
- transformers/models/luke/modeling_luke.py +179 -181
- transformers/models/luke/tokenization_luke.py +99 -105
- transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
- transformers/models/lw_detr/configuration_lw_detr.py +362 -0
- transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
- transformers/models/lw_detr/modular_lw_detr.py +1609 -0
- transformers/models/lxmert/configuration_lxmert.py +16 -1
- transformers/models/lxmert/modeling_lxmert.py +63 -74
- transformers/models/m2m_100/configuration_m2m_100.py +7 -9
- transformers/models/m2m_100/modeling_m2m_100.py +72 -74
- transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
- transformers/models/mamba/configuration_mamba.py +5 -3
- transformers/models/mamba/modeling_mamba.py +61 -70
- transformers/models/mamba2/configuration_mamba2.py +5 -8
- transformers/models/mamba2/modeling_mamba2.py +66 -79
- transformers/models/marian/configuration_marian.py +10 -5
- transformers/models/marian/modeling_marian.py +88 -90
- transformers/models/marian/tokenization_marian.py +6 -6
- transformers/models/markuplm/configuration_markuplm.py +4 -7
- transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
- transformers/models/markuplm/modeling_markuplm.py +63 -65
- transformers/models/markuplm/processing_markuplm.py +31 -38
- transformers/models/markuplm/tokenization_markuplm.py +67 -77
- transformers/models/mask2former/configuration_mask2former.py +14 -52
- transformers/models/mask2former/image_processing_mask2former.py +84 -85
- transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
- transformers/models/mask2former/modeling_mask2former.py +108 -104
- transformers/models/mask2former/modular_mask2former.py +6 -8
- transformers/models/maskformer/configuration_maskformer.py +17 -51
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
- transformers/models/maskformer/image_processing_maskformer.py +84 -85
- transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
- transformers/models/maskformer/modeling_maskformer.py +71 -67
- transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
- transformers/models/mbart/configuration_mbart.py +9 -5
- transformers/models/mbart/modeling_mbart.py +120 -119
- transformers/models/mbart/tokenization_mbart.py +2 -4
- transformers/models/mbart50/tokenization_mbart50.py +3 -5
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
- transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
- transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
- transformers/models/mgp_str/configuration_mgp_str.py +0 -1
- transformers/models/mgp_str/modeling_mgp_str.py +18 -18
- transformers/models/mgp_str/processing_mgp_str.py +3 -20
- transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
- transformers/models/mimi/configuration_mimi.py +42 -40
- transformers/models/mimi/modeling_mimi.py +116 -115
- transformers/models/minimax/__init__.py +0 -1
- transformers/models/minimax/configuration_minimax.py +40 -47
- transformers/models/minimax/modeling_minimax.py +46 -49
- transformers/models/minimax/modular_minimax.py +59 -65
- transformers/models/minimax_m2/__init__.py +28 -0
- transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
- transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
- transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
- transformers/models/ministral/configuration_ministral.py +25 -29
- transformers/models/ministral/modeling_ministral.py +35 -37
- transformers/models/ministral/modular_ministral.py +32 -37
- transformers/models/ministral3/configuration_ministral3.py +23 -26
- transformers/models/ministral3/modeling_ministral3.py +35 -37
- transformers/models/ministral3/modular_ministral3.py +7 -8
- transformers/models/mistral/configuration_mistral.py +24 -29
- transformers/models/mistral/modeling_mistral.py +35 -37
- transformers/models/mistral/modular_mistral.py +14 -15
- transformers/models/mistral3/configuration_mistral3.py +4 -1
- transformers/models/mistral3/modeling_mistral3.py +79 -82
- transformers/models/mistral3/modular_mistral3.py +66 -67
- transformers/models/mixtral/configuration_mixtral.py +32 -38
- transformers/models/mixtral/modeling_mixtral.py +39 -42
- transformers/models/mixtral/modular_mixtral.py +26 -29
- transformers/models/mlcd/configuration_mlcd.py +0 -1
- transformers/models/mlcd/modeling_mlcd.py +17 -17
- transformers/models/mlcd/modular_mlcd.py +16 -16
- transformers/models/mllama/configuration_mllama.py +10 -15
- transformers/models/mllama/image_processing_mllama.py +23 -25
- transformers/models/mllama/image_processing_mllama_fast.py +11 -11
- transformers/models/mllama/modeling_mllama.py +100 -103
- transformers/models/mllama/processing_mllama.py +6 -55
- transformers/models/mluke/tokenization_mluke.py +97 -103
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
- transformers/models/mobilebert/configuration_mobilebert.py +4 -2
- transformers/models/mobilebert/modeling_mobilebert.py +78 -88
- transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
- transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
- transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
- transformers/models/mobilevit/configuration_mobilevit.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
- transformers/models/mobilevit/modeling_mobilevit.py +21 -21
- transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
- transformers/models/modernbert/configuration_modernbert.py +76 -51
- transformers/models/modernbert/modeling_modernbert.py +188 -943
- transformers/models/modernbert/modular_modernbert.py +255 -978
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
- transformers/models/moonshine/configuration_moonshine.py +34 -31
- transformers/models/moonshine/modeling_moonshine.py +70 -72
- transformers/models/moonshine/modular_moonshine.py +91 -86
- transformers/models/moshi/configuration_moshi.py +46 -23
- transformers/models/moshi/modeling_moshi.py +134 -142
- transformers/models/mpnet/configuration_mpnet.py +6 -2
- transformers/models/mpnet/modeling_mpnet.py +55 -57
- transformers/models/mpnet/tokenization_mpnet.py +1 -4
- transformers/models/mpt/configuration_mpt.py +17 -9
- transformers/models/mpt/modeling_mpt.py +58 -60
- transformers/models/mra/configuration_mra.py +8 -2
- transformers/models/mra/modeling_mra.py +54 -56
- transformers/models/mt5/configuration_mt5.py +9 -6
- transformers/models/mt5/modeling_mt5.py +80 -85
- transformers/models/musicgen/configuration_musicgen.py +12 -8
- transformers/models/musicgen/modeling_musicgen.py +114 -116
- transformers/models/musicgen/processing_musicgen.py +3 -21
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
- transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
- transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
- transformers/models/mvp/configuration_mvp.py +8 -5
- transformers/models/mvp/modeling_mvp.py +121 -123
- transformers/models/myt5/tokenization_myt5.py +8 -10
- transformers/models/nanochat/configuration_nanochat.py +5 -8
- transformers/models/nanochat/modeling_nanochat.py +36 -39
- transformers/models/nanochat/modular_nanochat.py +16 -18
- transformers/models/nemotron/configuration_nemotron.py +25 -30
- transformers/models/nemotron/modeling_nemotron.py +53 -66
- transformers/models/nllb/tokenization_nllb.py +14 -14
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
- transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
- transformers/models/nougat/image_processing_nougat.py +29 -32
- transformers/models/nougat/image_processing_nougat_fast.py +12 -13
- transformers/models/nougat/processing_nougat.py +37 -39
- transformers/models/nougat/tokenization_nougat.py +5 -7
- transformers/models/nystromformer/configuration_nystromformer.py +8 -2
- transformers/models/nystromformer/modeling_nystromformer.py +61 -63
- transformers/models/olmo/configuration_olmo.py +23 -28
- transformers/models/olmo/modeling_olmo.py +35 -38
- transformers/models/olmo/modular_olmo.py +8 -12
- transformers/models/olmo2/configuration_olmo2.py +27 -32
- transformers/models/olmo2/modeling_olmo2.py +36 -39
- transformers/models/olmo2/modular_olmo2.py +36 -38
- transformers/models/olmo3/__init__.py +0 -1
- transformers/models/olmo3/configuration_olmo3.py +30 -34
- transformers/models/olmo3/modeling_olmo3.py +35 -38
- transformers/models/olmo3/modular_olmo3.py +44 -47
- transformers/models/olmoe/configuration_olmoe.py +29 -33
- transformers/models/olmoe/modeling_olmoe.py +41 -43
- transformers/models/olmoe/modular_olmoe.py +15 -16
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
- transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
- transformers/models/oneformer/configuration_oneformer.py +11 -51
- transformers/models/oneformer/image_processing_oneformer.py +83 -84
- transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
- transformers/models/oneformer/modeling_oneformer.py +137 -133
- transformers/models/oneformer/processing_oneformer.py +28 -43
- transformers/models/openai/configuration_openai.py +16 -1
- transformers/models/openai/modeling_openai.py +50 -51
- transformers/models/openai/tokenization_openai.py +2 -5
- transformers/models/opt/configuration_opt.py +6 -7
- transformers/models/opt/modeling_opt.py +79 -80
- transformers/models/ovis2/__init__.py +0 -1
- transformers/models/ovis2/configuration_ovis2.py +4 -1
- transformers/models/ovis2/image_processing_ovis2.py +22 -24
- transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
- transformers/models/ovis2/modeling_ovis2.py +99 -142
- transformers/models/ovis2/modular_ovis2.py +82 -45
- transformers/models/ovis2/processing_ovis2.py +12 -40
- transformers/models/owlv2/configuration_owlv2.py +4 -2
- transformers/models/owlv2/image_processing_owlv2.py +20 -21
- transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
- transformers/models/owlv2/modeling_owlv2.py +122 -114
- transformers/models/owlv2/modular_owlv2.py +11 -12
- transformers/models/owlv2/processing_owlv2.py +20 -49
- transformers/models/owlvit/configuration_owlvit.py +4 -2
- transformers/models/owlvit/image_processing_owlvit.py +21 -22
- transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
- transformers/models/owlvit/modeling_owlvit.py +121 -113
- transformers/models/owlvit/processing_owlvit.py +20 -48
- transformers/models/paddleocr_vl/__init__.py +0 -1
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
- transformers/models/paligemma/configuration_paligemma.py +4 -1
- transformers/models/paligemma/modeling_paligemma.py +81 -79
- transformers/models/paligemma/processing_paligemma.py +13 -66
- transformers/models/parakeet/configuration_parakeet.py +3 -8
- transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
- transformers/models/parakeet/modeling_parakeet.py +21 -25
- transformers/models/parakeet/modular_parakeet.py +19 -21
- transformers/models/parakeet/processing_parakeet.py +12 -5
- transformers/models/parakeet/tokenization_parakeet.py +2 -4
- transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
- transformers/models/patchtst/configuration_patchtst.py +6 -9
- transformers/models/patchtst/modeling_patchtst.py +75 -77
- transformers/models/pe_audio/__init__.py +0 -1
- transformers/models/pe_audio/configuration_pe_audio.py +14 -16
- transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
- transformers/models/pe_audio/modeling_pe_audio.py +30 -31
- transformers/models/pe_audio/modular_pe_audio.py +17 -18
- transformers/models/pe_audio/processing_pe_audio.py +0 -1
- transformers/models/pe_audio_video/__init__.py +0 -1
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
- transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
- transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
- transformers/models/pe_video/__init__.py +0 -1
- transformers/models/pe_video/configuration_pe_video.py +14 -16
- transformers/models/pe_video/modeling_pe_video.py +57 -46
- transformers/models/pe_video/modular_pe_video.py +47 -35
- transformers/models/pe_video/video_processing_pe_video.py +2 -4
- transformers/models/pegasus/configuration_pegasus.py +8 -6
- transformers/models/pegasus/modeling_pegasus.py +67 -69
- transformers/models/pegasus/tokenization_pegasus.py +1 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
- transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
- transformers/models/perceiver/configuration_perceiver.py +0 -1
- transformers/models/perceiver/image_processing_perceiver.py +22 -25
- transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
- transformers/models/perceiver/modeling_perceiver.py +152 -145
- transformers/models/perceiver/tokenization_perceiver.py +3 -6
- transformers/models/perception_lm/configuration_perception_lm.py +0 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
- transformers/models/perception_lm/modeling_perception_lm.py +64 -67
- transformers/models/perception_lm/modular_perception_lm.py +58 -58
- transformers/models/perception_lm/processing_perception_lm.py +13 -47
- transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
- transformers/models/persimmon/configuration_persimmon.py +23 -28
- transformers/models/persimmon/modeling_persimmon.py +44 -47
- transformers/models/phi/configuration_phi.py +27 -28
- transformers/models/phi/modeling_phi.py +39 -41
- transformers/models/phi/modular_phi.py +26 -26
- transformers/models/phi3/configuration_phi3.py +32 -37
- transformers/models/phi3/modeling_phi3.py +37 -40
- transformers/models/phi3/modular_phi3.py +16 -20
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
- transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
- transformers/models/phimoe/configuration_phimoe.py +31 -36
- transformers/models/phimoe/modeling_phimoe.py +50 -77
- transformers/models/phimoe/modular_phimoe.py +12 -8
- transformers/models/phobert/tokenization_phobert.py +4 -6
- transformers/models/pix2struct/configuration_pix2struct.py +12 -10
- transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
- transformers/models/pix2struct/modeling_pix2struct.py +56 -52
- transformers/models/pix2struct/processing_pix2struct.py +5 -26
- transformers/models/pixio/__init__.py +0 -1
- transformers/models/pixio/configuration_pixio.py +2 -5
- transformers/models/pixio/modeling_pixio.py +16 -17
- transformers/models/pixio/modular_pixio.py +7 -8
- transformers/models/pixtral/configuration_pixtral.py +11 -14
- transformers/models/pixtral/image_processing_pixtral.py +26 -28
- transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
- transformers/models/pixtral/modeling_pixtral.py +31 -37
- transformers/models/pixtral/processing_pixtral.py +18 -52
- transformers/models/plbart/configuration_plbart.py +8 -6
- transformers/models/plbart/modeling_plbart.py +109 -109
- transformers/models/plbart/modular_plbart.py +31 -33
- transformers/models/plbart/tokenization_plbart.py +4 -5
- transformers/models/poolformer/configuration_poolformer.py +0 -1
- transformers/models/poolformer/image_processing_poolformer.py +21 -24
- transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
- transformers/models/poolformer/modeling_poolformer.py +10 -12
- transformers/models/pop2piano/configuration_pop2piano.py +7 -7
- transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
- transformers/models/pop2piano/modeling_pop2piano.py +24 -24
- transformers/models/pop2piano/processing_pop2piano.py +25 -33
- transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
- transformers/models/prophetnet/configuration_prophetnet.py +37 -38
- transformers/models/prophetnet/modeling_prophetnet.py +121 -153
- transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
- transformers/models/pvt/configuration_pvt.py +0 -1
- transformers/models/pvt/image_processing_pvt.py +24 -27
- transformers/models/pvt/image_processing_pvt_fast.py +1 -2
- transformers/models/pvt/modeling_pvt.py +19 -21
- transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
- transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
- transformers/models/qwen2/configuration_qwen2.py +32 -25
- transformers/models/qwen2/modeling_qwen2.py +35 -37
- transformers/models/qwen2/modular_qwen2.py +14 -15
- transformers/models/qwen2/tokenization_qwen2.py +2 -9
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
- transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
- transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
- transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
- transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
- transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
- transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
- transformers/models/qwen3/configuration_qwen3.py +34 -27
- transformers/models/qwen3/modeling_qwen3.py +35 -38
- transformers/models/qwen3/modular_qwen3.py +7 -9
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
- transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
- transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
- transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
- transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
- transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
- transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
- transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
- transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
- transformers/models/rag/configuration_rag.py +6 -7
- transformers/models/rag/modeling_rag.py +119 -121
- transformers/models/rag/retrieval_rag.py +3 -5
- transformers/models/rag/tokenization_rag.py +0 -50
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
- transformers/models/reformer/configuration_reformer.py +7 -8
- transformers/models/reformer/modeling_reformer.py +67 -68
- transformers/models/reformer/tokenization_reformer.py +3 -6
- transformers/models/regnet/configuration_regnet.py +0 -1
- transformers/models/regnet/modeling_regnet.py +7 -9
- transformers/models/rembert/configuration_rembert.py +8 -2
- transformers/models/rembert/modeling_rembert.py +108 -132
- transformers/models/rembert/tokenization_rembert.py +1 -4
- transformers/models/resnet/configuration_resnet.py +2 -5
- transformers/models/resnet/modeling_resnet.py +14 -15
- transformers/models/roberta/configuration_roberta.py +11 -3
- transformers/models/roberta/modeling_roberta.py +97 -99
- transformers/models/roberta/modular_roberta.py +55 -58
- transformers/models/roberta/tokenization_roberta.py +2 -5
- transformers/models/roberta/tokenization_roberta_old.py +2 -4
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
- transformers/models/roc_bert/configuration_roc_bert.py +8 -2
- transformers/models/roc_bert/modeling_roc_bert.py +125 -162
- transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
- transformers/models/roformer/configuration_roformer.py +13 -3
- transformers/models/roformer/modeling_roformer.py +79 -95
- transformers/models/roformer/tokenization_roformer.py +3 -6
- transformers/models/roformer/tokenization_utils.py +0 -1
- transformers/models/rt_detr/configuration_rt_detr.py +8 -50
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
- transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
- transformers/models/rt_detr/modeling_rt_detr.py +643 -804
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
- transformers/models/rt_detr/modular_rt_detr.py +1522 -20
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
- transformers/models/rwkv/configuration_rwkv.py +2 -4
- transformers/models/rwkv/modeling_rwkv.py +29 -54
- transformers/models/sam/configuration_sam.py +2 -1
- transformers/models/sam/image_processing_sam.py +59 -60
- transformers/models/sam/image_processing_sam_fast.py +25 -26
- transformers/models/sam/modeling_sam.py +46 -43
- transformers/models/sam/processing_sam.py +39 -27
- transformers/models/sam2/configuration_sam2.py +1 -2
- transformers/models/sam2/image_processing_sam2_fast.py +14 -15
- transformers/models/sam2/modeling_sam2.py +96 -94
- transformers/models/sam2/modular_sam2.py +85 -94
- transformers/models/sam2/processing_sam2.py +31 -47
- transformers/models/sam2_video/configuration_sam2_video.py +0 -1
- transformers/models/sam2_video/modeling_sam2_video.py +114 -116
- transformers/models/sam2_video/modular_sam2_video.py +72 -89
- transformers/models/sam2_video/processing_sam2_video.py +49 -66
- transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
- transformers/models/sam3/configuration_sam3.py +0 -1
- transformers/models/sam3/image_processing_sam3_fast.py +17 -20
- transformers/models/sam3/modeling_sam3.py +94 -100
- transformers/models/sam3/modular_sam3.py +3 -8
- transformers/models/sam3/processing_sam3.py +37 -52
- transformers/models/sam3_tracker/__init__.py +0 -1
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
- transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
- transformers/models/sam3_tracker_video/__init__.py +0 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
- transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
- transformers/models/sam3_video/configuration_sam3_video.py +0 -1
- transformers/models/sam3_video/modeling_sam3_video.py +56 -45
- transformers/models/sam3_video/processing_sam3_video.py +25 -45
- transformers/models/sam_hq/__init__.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +2 -1
- transformers/models/sam_hq/modeling_sam_hq.py +52 -50
- transformers/models/sam_hq/modular_sam_hq.py +23 -25
- transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
- transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
- transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
- transformers/models/seed_oss/configuration_seed_oss.py +30 -34
- transformers/models/seed_oss/modeling_seed_oss.py +34 -36
- transformers/models/seed_oss/modular_seed_oss.py +6 -7
- transformers/models/segformer/configuration_segformer.py +0 -10
- transformers/models/segformer/image_processing_segformer.py +39 -42
- transformers/models/segformer/image_processing_segformer_fast.py +11 -12
- transformers/models/segformer/modeling_segformer.py +28 -28
- transformers/models/segformer/modular_segformer.py +8 -9
- transformers/models/seggpt/configuration_seggpt.py +0 -1
- transformers/models/seggpt/image_processing_seggpt.py +38 -41
- transformers/models/seggpt/modeling_seggpt.py +48 -38
- transformers/models/sew/configuration_sew.py +4 -2
- transformers/models/sew/modeling_sew.py +42 -40
- transformers/models/sew/modular_sew.py +12 -13
- transformers/models/sew_d/configuration_sew_d.py +4 -2
- transformers/models/sew_d/modeling_sew_d.py +32 -31
- transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
- transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
- transformers/models/siglip/configuration_siglip.py +4 -2
- transformers/models/siglip/image_processing_siglip.py +17 -20
- transformers/models/siglip/image_processing_siglip_fast.py +0 -1
- transformers/models/siglip/modeling_siglip.py +65 -110
- transformers/models/siglip/processing_siglip.py +2 -14
- transformers/models/siglip/tokenization_siglip.py +6 -7
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2.py +15 -16
- transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
- transformers/models/siglip2/modeling_siglip2.py +89 -130
- transformers/models/siglip2/modular_siglip2.py +95 -48
- transformers/models/siglip2/processing_siglip2.py +2 -14
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +29 -32
- transformers/models/smollm3/modeling_smollm3.py +35 -38
- transformers/models/smollm3/modular_smollm3.py +36 -38
- transformers/models/smolvlm/configuration_smolvlm.py +2 -4
- transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
- transformers/models/smolvlm/modeling_smolvlm.py +124 -96
- transformers/models/smolvlm/modular_smolvlm.py +50 -39
- transformers/models/smolvlm/processing_smolvlm.py +15 -76
- transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
- transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
- transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
- transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
- transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
- transformers/models/speecht5/configuration_speecht5.py +7 -9
- transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
- transformers/models/speecht5/modeling_speecht5.py +172 -174
- transformers/models/speecht5/number_normalizer.py +0 -1
- transformers/models/speecht5/processing_speecht5.py +3 -37
- transformers/models/speecht5/tokenization_speecht5.py +4 -5
- transformers/models/splinter/configuration_splinter.py +6 -7
- transformers/models/splinter/modeling_splinter.py +62 -59
- transformers/models/splinter/tokenization_splinter.py +2 -4
- transformers/models/squeezebert/configuration_squeezebert.py +14 -2
- transformers/models/squeezebert/modeling_squeezebert.py +60 -62
- transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
- transformers/models/stablelm/configuration_stablelm.py +28 -29
- transformers/models/stablelm/modeling_stablelm.py +44 -47
- transformers/models/starcoder2/configuration_starcoder2.py +30 -27
- transformers/models/starcoder2/modeling_starcoder2.py +38 -41
- transformers/models/starcoder2/modular_starcoder2.py +17 -19
- transformers/models/superglue/configuration_superglue.py +7 -3
- transformers/models/superglue/image_processing_superglue.py +15 -15
- transformers/models/superglue/image_processing_superglue_fast.py +8 -8
- transformers/models/superglue/modeling_superglue.py +41 -37
- transformers/models/superpoint/image_processing_superpoint.py +15 -15
- transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
- transformers/models/superpoint/modeling_superpoint.py +17 -16
- transformers/models/swiftformer/configuration_swiftformer.py +0 -1
- transformers/models/swiftformer/modeling_swiftformer.py +12 -14
- transformers/models/swin/configuration_swin.py +2 -5
- transformers/models/swin/modeling_swin.py +69 -78
- transformers/models/swin2sr/configuration_swin2sr.py +0 -1
- transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
- transformers/models/swin2sr/modeling_swin2sr.py +30 -30
- transformers/models/swinv2/configuration_swinv2.py +2 -5
- transformers/models/swinv2/modeling_swinv2.py +65 -74
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
- transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
- transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
- transformers/models/t5/configuration_t5.py +9 -9
- transformers/models/t5/modeling_t5.py +80 -85
- transformers/models/t5/tokenization_t5.py +1 -3
- transformers/models/t5gemma/configuration_t5gemma.py +43 -59
- transformers/models/t5gemma/modeling_t5gemma.py +105 -108
- transformers/models/t5gemma/modular_t5gemma.py +128 -142
- transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
- transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
- transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
- transformers/models/table_transformer/configuration_table_transformer.py +18 -50
- transformers/models/table_transformer/modeling_table_transformer.py +73 -101
- transformers/models/tapas/configuration_tapas.py +12 -2
- transformers/models/tapas/modeling_tapas.py +65 -67
- transformers/models/tapas/tokenization_tapas.py +116 -153
- transformers/models/textnet/configuration_textnet.py +4 -7
- transformers/models/textnet/image_processing_textnet.py +22 -25
- transformers/models/textnet/image_processing_textnet_fast.py +8 -9
- transformers/models/textnet/modeling_textnet.py +28 -28
- transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
- transformers/models/timesfm/configuration_timesfm.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +22 -25
- transformers/models/timesfm/modular_timesfm.py +21 -24
- transformers/models/timesformer/configuration_timesformer.py +0 -1
- transformers/models/timesformer/modeling_timesformer.py +13 -16
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
- transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
- transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
- transformers/models/trocr/configuration_trocr.py +11 -8
- transformers/models/trocr/modeling_trocr.py +42 -42
- transformers/models/trocr/processing_trocr.py +5 -25
- transformers/models/tvp/configuration_tvp.py +10 -36
- transformers/models/tvp/image_processing_tvp.py +50 -52
- transformers/models/tvp/image_processing_tvp_fast.py +15 -15
- transformers/models/tvp/modeling_tvp.py +26 -28
- transformers/models/tvp/processing_tvp.py +2 -14
- transformers/models/udop/configuration_udop.py +16 -8
- transformers/models/udop/modeling_udop.py +73 -72
- transformers/models/udop/processing_udop.py +7 -26
- transformers/models/udop/tokenization_udop.py +80 -93
- transformers/models/umt5/configuration_umt5.py +8 -7
- transformers/models/umt5/modeling_umt5.py +87 -84
- transformers/models/unispeech/configuration_unispeech.py +4 -2
- transformers/models/unispeech/modeling_unispeech.py +54 -53
- transformers/models/unispeech/modular_unispeech.py +20 -22
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
- transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
- transformers/models/univnet/feature_extraction_univnet.py +14 -14
- transformers/models/univnet/modeling_univnet.py +7 -8
- transformers/models/upernet/configuration_upernet.py +8 -36
- transformers/models/upernet/modeling_upernet.py +11 -14
- transformers/models/vaultgemma/__init__.py +0 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
- transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
- transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
- transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
- transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
- transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
- transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
- transformers/models/video_llava/configuration_video_llava.py +4 -1
- transformers/models/video_llava/image_processing_video_llava.py +35 -38
- transformers/models/video_llava/modeling_video_llava.py +139 -143
- transformers/models/video_llava/processing_video_llava.py +38 -78
- transformers/models/video_llava/video_processing_video_llava.py +0 -1
- transformers/models/videomae/configuration_videomae.py +0 -1
- transformers/models/videomae/image_processing_videomae.py +31 -34
- transformers/models/videomae/modeling_videomae.py +17 -20
- transformers/models/videomae/video_processing_videomae.py +0 -1
- transformers/models/vilt/configuration_vilt.py +4 -2
- transformers/models/vilt/image_processing_vilt.py +29 -30
- transformers/models/vilt/image_processing_vilt_fast.py +15 -16
- transformers/models/vilt/modeling_vilt.py +103 -90
- transformers/models/vilt/processing_vilt.py +2 -14
- transformers/models/vipllava/configuration_vipllava.py +4 -1
- transformers/models/vipllava/modeling_vipllava.py +92 -67
- transformers/models/vipllava/modular_vipllava.py +78 -54
- transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
- transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
- transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
- transformers/models/visual_bert/configuration_visual_bert.py +6 -2
- transformers/models/visual_bert/modeling_visual_bert.py +90 -92
- transformers/models/vit/configuration_vit.py +2 -3
- transformers/models/vit/image_processing_vit.py +19 -22
- transformers/models/vit/image_processing_vit_fast.py +0 -1
- transformers/models/vit/modeling_vit.py +20 -20
- transformers/models/vit_mae/configuration_vit_mae.py +0 -1
- transformers/models/vit_mae/modeling_vit_mae.py +32 -30
- transformers/models/vit_msn/configuration_vit_msn.py +0 -1
- transformers/models/vit_msn/modeling_vit_msn.py +21 -19
- transformers/models/vitdet/configuration_vitdet.py +2 -5
- transformers/models/vitdet/modeling_vitdet.py +14 -17
- transformers/models/vitmatte/configuration_vitmatte.py +7 -39
- transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
- transformers/models/vitmatte/modeling_vitmatte.py +10 -12
- transformers/models/vitpose/configuration_vitpose.py +7 -47
- transformers/models/vitpose/image_processing_vitpose.py +24 -25
- transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
- transformers/models/vitpose/modeling_vitpose.py +15 -15
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
- transformers/models/vits/configuration_vits.py +4 -1
- transformers/models/vits/modeling_vits.py +43 -42
- transformers/models/vits/tokenization_vits.py +3 -4
- transformers/models/vivit/configuration_vivit.py +0 -1
- transformers/models/vivit/image_processing_vivit.py +36 -39
- transformers/models/vivit/modeling_vivit.py +9 -11
- transformers/models/vjepa2/__init__.py +0 -1
- transformers/models/vjepa2/configuration_vjepa2.py +0 -1
- transformers/models/vjepa2/modeling_vjepa2.py +39 -41
- transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
- transformers/models/voxtral/__init__.py +0 -1
- transformers/models/voxtral/configuration_voxtral.py +0 -2
- transformers/models/voxtral/modeling_voxtral.py +41 -48
- transformers/models/voxtral/modular_voxtral.py +35 -38
- transformers/models/voxtral/processing_voxtral.py +25 -48
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
- transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
- transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
- transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
- transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
- transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
- transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
- transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
- transformers/models/wavlm/configuration_wavlm.py +4 -2
- transformers/models/wavlm/modeling_wavlm.py +49 -49
- transformers/models/wavlm/modular_wavlm.py +4 -5
- transformers/models/whisper/configuration_whisper.py +6 -5
- transformers/models/whisper/english_normalizer.py +3 -4
- transformers/models/whisper/feature_extraction_whisper.py +9 -24
- transformers/models/whisper/generation_whisper.py +26 -49
- transformers/models/whisper/modeling_whisper.py +71 -73
- transformers/models/whisper/processing_whisper.py +3 -20
- transformers/models/whisper/tokenization_whisper.py +9 -30
- transformers/models/x_clip/configuration_x_clip.py +4 -2
- transformers/models/x_clip/modeling_x_clip.py +94 -96
- transformers/models/x_clip/processing_x_clip.py +2 -14
- transformers/models/xcodec/configuration_xcodec.py +4 -6
- transformers/models/xcodec/modeling_xcodec.py +15 -17
- transformers/models/xglm/configuration_xglm.py +9 -8
- transformers/models/xglm/modeling_xglm.py +49 -55
- transformers/models/xglm/tokenization_xglm.py +1 -4
- transformers/models/xlm/configuration_xlm.py +10 -8
- transformers/models/xlm/modeling_xlm.py +127 -131
- transformers/models/xlm/tokenization_xlm.py +3 -5
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
- transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
- transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
- transformers/models/xlnet/configuration_xlnet.py +3 -12
- transformers/models/xlnet/modeling_xlnet.py +149 -162
- transformers/models/xlnet/tokenization_xlnet.py +1 -4
- transformers/models/xlstm/configuration_xlstm.py +8 -12
- transformers/models/xlstm/modeling_xlstm.py +61 -96
- transformers/models/xmod/configuration_xmod.py +11 -3
- transformers/models/xmod/modeling_xmod.py +111 -116
- transformers/models/yolos/configuration_yolos.py +0 -1
- transformers/models/yolos/image_processing_yolos.py +60 -62
- transformers/models/yolos/image_processing_yolos_fast.py +42 -45
- transformers/models/yolos/modeling_yolos.py +19 -21
- transformers/models/yolos/modular_yolos.py +17 -19
- transformers/models/yoso/configuration_yoso.py +8 -2
- transformers/models/yoso/modeling_yoso.py +60 -62
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -8
- transformers/models/zamba/modeling_zamba.py +93 -125
- transformers/models/zamba2/configuration_zamba2.py +44 -50
- transformers/models/zamba2/modeling_zamba2.py +137 -165
- transformers/models/zamba2/modular_zamba2.py +79 -74
- transformers/models/zoedepth/configuration_zoedepth.py +17 -41
- transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
- transformers/models/zoedepth/modeling_zoedepth.py +19 -19
- transformers/pipelines/__init__.py +47 -106
- transformers/pipelines/any_to_any.py +15 -23
- transformers/pipelines/audio_utils.py +1 -2
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +13 -17
- transformers/pipelines/image_text_to_text.py +1 -2
- transformers/pipelines/question_answering.py +4 -43
- transformers/pipelines/text_classification.py +1 -14
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/pipelines/token_classification.py +1 -22
- transformers/pipelines/video_classification.py +1 -9
- transformers/pipelines/zero_shot_audio_classification.py +0 -1
- transformers/pipelines/zero_shot_classification.py +0 -6
- transformers/pipelines/zero_shot_image_classification.py +0 -7
- transformers/processing_utils.py +128 -137
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/base.py +10 -0
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_quark.py +0 -1
- transformers/quantizers/quantizer_torchao.py +3 -19
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +6 -65
- transformers/tokenization_mistral_common.py +563 -903
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +228 -341
- transformers/tokenization_utils_sentencepiece.py +5 -6
- transformers/tokenization_utils_tokenizers.py +36 -7
- transformers/trainer.py +30 -41
- transformers/trainer_jit_checkpoint.py +1 -2
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +414 -420
- transformers/utils/__init__.py +1 -4
- transformers/utils/attention_visualizer.py +1 -1
- transformers/utils/auto_docstring.py +567 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/dummy_pt_objects.py +0 -42
- transformers/utils/generic.py +70 -34
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +135 -107
- transformers/utils/quantization_config.py +8 -31
- transformers/video_processing_utils.py +24 -25
- transformers/video_utils.py +21 -23
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
- transformers-5.1.0.dist-info/RECORD +2092 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -229
- transformers-5.0.0rc2.dist-info/RECORD +0 -2042
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0

@@ -14,18 +14,17 @@
 # limitations under the License.

 import math
-from
-from typing import Literal, Optional, Union
+from typing import Literal, Optional

 import torch
-import torch.nn.functional as F
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ... import initialization as init
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
-from ...
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
+from ...masking_utils import create_bidirectional_mask, create_bidirectional_sliding_window_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -36,18 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
-from ...modeling_utils import PreTrainedModel
-from ...
-from ...utils
-from
-
-
-if is_flash_attn_2_available():
-    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
-    from flash_attn.layers.rotary import RotaryEmbedding
-    from flash_attn.ops.triton.rotary import apply_rotary
-else:
-    RotaryEmbedding = object
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, logging
+from ...utils.generic import can_return_tuple, check_model_inputs
+from ..align.modeling_align import eager_attention_forward
+from ..gemma3.modeling_gemma3 import Gemma3RotaryEmbedding, rotate_half


 logger = logging.get_logger(__name__)
@@ -104,10 +97,9 @@ class ModernBertConfig(PreTrainedConfig):
             The dropout ratio for the attention probabilities.
         layer_types (`list`, *optional*):
             Attention pattern for each layer.
-        rope_parameters (`
-            Dictionary
-
-            with longer `max_position_embeddings`.
+        rope_parameters (`dict`, *optional*):
+            Dictionary mapping attention patterns (`"full_attention"`, `"sliding_attention"`) to `RopeParameters`.
+            Each value should be a dictionary containing `rope_type` and optional scaling parameters.
         local_attention (`int`, *optional*, defaults to 128):
             The window size for local attention.
         embedding_dropout (`float`, *optional*, defaults to 0.0):
@@ -137,10 +129,9 @@ class ModernBertConfig(PreTrainedConfig):
             Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
             the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
             shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
-            be faster in some scenarios.
-
-
-            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
+            be faster in some scenarios. This argument is deprecated and will be removed in a future version.
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether to tie weight embeddings

     Examples:

@@ -161,44 +152,59 @@ class ModernBertConfig(PreTrainedConfig):
     keys_to_ignore_at_inference = ["past_key_values"]
     default_theta = {"global": 160_000.0, "local": 10_000.0}

+    def __setattr__(self, name, value):
+        if name == "reference_compile" and value is not None:
+            logger.warning_once(
+                "The `reference_compile` argument is deprecated and will be removed in `transformers v5.2.0`"
+                "Use `torch.compile()` directly on the model instead."
+            )
+            value = None
+        super().__setattr__(name, value)
+
     def __init__(
         self,
-        vocab_size:
-        hidden_size:
-        intermediate_size:
-        num_hidden_layers:
-        num_attention_heads:
-        hidden_activation:
-        max_position_embeddings:
-        initializer_range:
-        initializer_cutoff_factor:
-        norm_eps:
-        norm_bias:
-        pad_token_id:
-        eos_token_id:
-        bos_token_id:
-        cls_token_id:
-        sep_token_id:
-        attention_bias:
-        attention_dropout:
-        layer_types:
-        rope_parameters:
-        local_attention:
-        embedding_dropout:
-        mlp_bias:
-        mlp_dropout:
-        decoder_bias:
+        vocab_size: int | None = 50368,
+        hidden_size: int | None = 768,
+        intermediate_size: int | None = 1152,
+        num_hidden_layers: int | None = 22,
+        num_attention_heads: int | None = 12,
+        hidden_activation: str | None = "gelu",
+        max_position_embeddings: int | None = 8192,
+        initializer_range: float | None = 0.02,
+        initializer_cutoff_factor: float | None = 2.0,
+        norm_eps: float | None = 1e-5,
+        norm_bias: bool | None = False,
+        pad_token_id: int | None = 50283,
+        eos_token_id: int | None = 50282,
+        bos_token_id: int | None = 50281,
+        cls_token_id: int | None = 50281,
+        sep_token_id: int | None = 50282,
+        attention_bias: bool | None = False,
+        attention_dropout: float | None = 0.0,
+        layer_types: list[str] | None = None,
+        rope_parameters: dict[Literal["full_attention", "sliding_attention"], RopeParameters] | None = None,
+        local_attention: int | None = 128,
+        embedding_dropout: float | None = 0.0,
+        mlp_bias: bool | None = False,
+        mlp_dropout: float | None = 0.0,
+        decoder_bias: bool | None = True,
         classifier_pooling: Literal["cls", "mean"] = "cls",
-        classifier_dropout:
-        classifier_bias:
-        classifier_activation:
-        deterministic_flash_attn:
-        sparse_prediction:
-        sparse_pred_ignore_index:
-        reference_compile:
-
+        classifier_dropout: float | None = 0.0,
+        classifier_bias: bool | None = False,
+        classifier_activation: str | None = "gelu",
+        deterministic_flash_attn: bool | None = False,
+        sparse_prediction: bool | None = False,
+        sparse_pred_ignore_index: int | None = -100,
+        reference_compile: bool | None = None,  # Deprecated
+        tie_word_embeddings: bool | None = True,
         **kwargs,
     ):
+        self.pad_token_id = pad_token_id
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+        self.cls_token_id = cls_token_id
+        self.sep_token_id = sep_token_id
+        self.tie_word_embeddings = tie_word_embeddings
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
@@ -225,7 +231,6 @@ class ModernBertConfig(PreTrainedConfig):
         self.sparse_prediction = sparse_prediction
         self.sparse_pred_ignore_index = sparse_pred_ignore_index
         self.reference_compile = reference_compile
-        self.repad_logits_with_grad = repad_logits_with_grad

         if self.classifier_pooling not in ["cls", "mean"]:
             raise ValueError(
@@ -245,14 +250,7 @@ class ModernBertConfig(PreTrainedConfig):
         layer_type_validation(self.layer_types, self.num_hidden_layers)

         self.rope_parameters = rope_parameters
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            cls_token_id=cls_token_id,
-            sep_token_id=sep_token_id,
-            **kwargs,
-        )
+        super().__init__(**kwargs)

     def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation=None, **kwargs):
         rope_scaling = kwargs.pop("rope_scaling", None)
@@ -267,9 +265,15 @@ class ModernBertConfig(PreTrainedConfig):
         if rope_scaling is not None:
             self.rope_parameters["full_attention"].update(rope_scaling)
             self.rope_parameters["sliding_attention"].update(rope_scaling)
+
+        # Set default values if not present
+        if self.rope_parameters.get("full_attention") is None:
+            self.rope_parameters["full_attention"] = {"rope_type": "default"}
         self.rope_parameters["full_attention"].setdefault(
             "rope_theta", kwargs.pop("global_rope_theta", self.default_theta["global"])
         )
+        if self.rope_parameters.get("sliding_attention") is None:
+            self.rope_parameters["sliding_attention"] = {"rope_type": "default"}
         self.rope_parameters["sliding_attention"].setdefault(
             "rope_theta", kwargs.pop("local_rope_theta", self.default_theta["local"])
         )
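
The two hunks directly above rework how `ModernBertConfig.convert_rope_params_to_dict` fills in missing rope settings: when the `"full_attention"` or `"sliding_attention"` entry is absent it is now seeded with `{"rope_type": "default"}` before the per-pattern `rope_theta` defaults are applied. The following standalone sketch restates that default-filling for readability; the names and theta values are copied from the diff, and it is an illustration rather than the actual transformers code.

    # Illustration only: mirrors the default-filling logic added in the hunk above.
    default_theta = {"global": 160_000.0, "local": 10_000.0}  # ModernBertConfig.default_theta in the diff

    def fill_rope_defaults(rope_parameters=None):
        rope_parameters = rope_parameters or {}
        if rope_parameters.get("full_attention") is None:
            rope_parameters["full_attention"] = {"rope_type": "default"}
        rope_parameters["full_attention"].setdefault("rope_theta", default_theta["global"])
        if rope_parameters.get("sliding_attention") is None:
            rope_parameters["sliding_attention"] = {"rope_type": "default"}
        rope_parameters["sliding_attention"].setdefault("rope_theta", default_theta["local"])
        return rope_parameters

    print(fill_rope_defaults())
    # {'full_attention': {'rope_type': 'default', 'rope_theta': 160000.0},
    #  'sliding_attention': {'rope_type': 'default', 'rope_theta': 10000.0}}
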
@@ -284,211 +288,15 @@ class ModernBertConfig(PreTrainedConfig):
         output.pop("reference_compile", None)
         return output

+    @property
+    def sliding_window(self):
+        """Half-window size: `local_attention` is the total window, so we divide by 2."""
+        return self.local_attention // 2

-
-
-
-
-    labels: Optional[torch.Tensor] = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
-    """
-    Remove padding from input sequences.
-
-    Args:
-        inputs: (batch, seqlen, ...) or (batch, seqlen)
-        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
-        position_ids: (batch, seqlen), int, position ids
-        labels: (batch, seqlen), int, labels
-
-    Returns:
-        unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
-        indices: (total_nnz)
-        cu_seqlens: (batch + 1), the cumulative sequence lengths
-        max_seqlen_in_batch: int
-        unpadded_position_ids: (total_nnz) or None
-        unpadded_labels: (total_nnz) or None
-    """
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
-    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-
-    if inputs.dim() == 2:
-        unpadded_inputs = inputs.flatten()[indices]
-    else:
-        batch, seqlen, *rest = inputs.shape
-        shape = batch * seqlen
-        unpadded_inputs = inputs.view(shape, *rest)[indices]
-
-    unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
-    unpadded_labels = labels.flatten()[indices] if labels is not None else None
-
-    return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
-
-
-def _pad_modernbert_output(
-    inputs: torch.Tensor,
-    indices: torch.Tensor,
-    batch: int,
-    seqlen: int,
-) -> torch.Tensor:
-    """
-    Add padding to sequences.
-
-    Args:
-        inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
-        indices: (total_nnz)
-        batch: int, batch size
-        seqlen: int, max sequence length
-
-    Returns:
-        padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
-    """
-    if inputs.dim() == 1:
-        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
-        output[indices] = inputs
-        padded_inputs = output.view(batch, seqlen)
-    else:
-        _, *rest = inputs.shape
-        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
-        output[indices] = inputs
-        padded_inputs = output.view(batch, seqlen, *rest)
-
-    return padded_inputs
-
-
-class ApplyRotaryEmbUnpad(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        qkv,
-        cos,
-        sin,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
-    ):
-        # (total_nnz, 3, nheads, headdim)
-        qkv = qkv.contiguous()
-        total_nnz, _three, _nheads, headdim = qkv.shape
-        # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
-        # we get the same tensor
-        # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
-        qk = qkv[:, :2].view(total_nnz, -1, headdim)
-        apply_rotary(
-            qk,
-            cos,
-            sin,
-            seqlen_offsets=0,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            interleaved=False,
-            inplace=True,
-        )
-
-        ctx.save_for_backward(cos, sin, cu_seqlens)
-        ctx.max_seqlen = max_seqlen
-        return qkv
-
-    @staticmethod
-    def backward(ctx, do):
-        cos, sin, cu_seqlens = ctx.saved_tensors
-        do = do.contiguous()
-        total_nnz, _three, _nheads, headdim = do.shape
-        # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
-        # we get the same tensor
-        dqk = do[:, :2].view(total_nnz, -1, headdim)
-        apply_rotary(
-            dqk,
-            cos,
-            sin,
-            seqlen_offsets=0,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=ctx.max_seqlen,
-            interleaved=False,
-            inplace=True,
-            conjugate=True,
-        )
-
-        return do, None, None, None, None, None, None
-
-
-def apply_rotary_unpadded(
-    qkv,
-    cos,
-    sin,
-    cu_seqlens: Optional[torch.Tensor] = None,
-    max_seqlen: Optional[int] = None,
-):
-    """
-    Arguments:
-        qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
-        cos, sin: (seqlen_rotary, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        inplace: if True, apply rotary embedding in-place.
-        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
-            Most commonly used in inference when we have KV cache.
-        cu_seqlens: (batch + 1,) or None
-        max_seqlen: int
-    Return:
-        out: (total_nnz, dim)
-    rotary_dim must be <= headdim
-    Apply rotary embedding to the first rotary_dim of x.
-    """
-    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
-
-
-class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
-    """
-    The rotary position embeddings applied directly to unpadded sequences.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        base: float = 10000.0,
-        max_seqlen: Optional[int] = None,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        """
-        max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
-            up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
-            the cos_sin_cache will be recomputed during the forward pass.
-        """
-        super().__init__(dim=dim, base=base, device=device, interleaved=False)
-        self.max_seqlen = max_seqlen
-
-        if max_seqlen is not None and device is not None and dtype is not None:
-            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
-
-    def forward(
-        self,
-        qkv: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        max_seqlen: Optional[int] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Apply rotary embedding *inplace* to qkv.
-        qkv: (total_nnz, 3, nheads, headdim)
-        cu_seqlens: (batch + 1,) cumulative sequence lengths
-        max_seqlen: int max seq length in the batch
-        """
-        if max_seqlen is not None:
-            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
|
479
|
-
|
|
480
|
-
qkv = apply_rotary_unpadded(
|
|
481
|
-
qkv,
|
|
482
|
-
self._cos_cached,
|
|
483
|
-
self._sin_cached,
|
|
484
|
-
cu_seqlens=cu_seqlens,
|
|
485
|
-
max_seqlen=max_seqlen,
|
|
486
|
-
)
|
|
487
|
-
|
|
488
|
-
return qkv
|
|
489
|
-
|
|
490
|
-
def extra_repr(self) -> str:
|
|
491
|
-
return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
|
|
+    @sliding_window.setter
+    def sliding_window(self, value):
+        """Set sliding_window by updating local_attention to 2 * value."""
+        self.local_attention = value * 2
 
 
 class ModernBertEmbeddings(nn.Module):
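The setter added above keeps `local_attention` and `sliding_window` in sync: `local_attention` is the full window, `sliding_window` the half-window. A small illustrative check (the values are made up):

```python
half_window = 64                    # value assigned via the new sliding_window setter
local_attention = half_window * 2   # what the setter stores on the config

# In a sliding-attention layer, position j is visible to position i iff |i - j| <= half_window.
i, j = 200, 140
print(abs(i - j) <= half_window)    # True: inside the 128-token local window
```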
@@ -503,21 +311,13 @@ class ModernBertEmbeddings(nn.Module):
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.drop = nn.Dropout(config.embedding_dropout)
 
-    @torch.compile(dynamic=True)
-    def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
-        return self.drop(self.norm(self.tok_embeddings(input_ids)))
-
     def forward(
-        self, input_ids:
+        self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.Tensor | None = None
     ) -> torch.Tensor:
         if inputs_embeds is not None:
             hidden_states = self.drop(self.norm(inputs_embeds))
         else:
-            hidden_states = (
-                self.compiled_embeddings(input_ids)
-                if self.config.reference_compile
-                else self.drop(self.norm(self.tok_embeddings(input_ids)))
-            )
+            hidden_states = self.drop(self.norm(self.tok_embeddings(input_ids)))
         return hidden_states
 
 
@@ -547,138 +347,42 @@ class ModernBertRotaryEmbedding(Gemma3RotaryEmbedding):
 
     @staticmethod
     def compute_default_rope_parameters(
-        config:
+        config: ModernBertConfig | None = None,
         device: Optional["torch.device"] = None,
-        seq_len:
-        layer_type:
+        seq_len: int | None = None,
+        layer_type: str | None = None,
     ) -> tuple["torch.Tensor", float]:
         return super().compute_default_rope_parameters(config, device, seq_len, layer_type)
 
 
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
attention_mask: torch.Tensor,
|
|
562
|
-
sliding_window_mask: torch.Tensor,
|
|
563
|
-
position_ids: Optional[torch.LongTensor],
|
|
564
|
-
local_attention: tuple[int, int],
|
|
565
|
-
bs: int,
|
|
566
|
-
dim: int,
|
|
567
|
-
position_embeddings: torch.Tensor,
|
|
568
|
-
output_attentions: Optional[bool] = False,
|
|
569
|
-
**_kwargs,
|
|
570
|
-
) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
|
|
571
|
-
# qkv: [batch_size, seqlen, 3, nheads, headdim]
|
|
572
|
-
cos, sin = position_embeddings
|
|
573
|
-
query, key, value = qkv.transpose(3, 1).unbind(dim=2)
|
|
574
|
-
# query, key, value: [batch_size, heads, seq_len, head_dim]
|
|
575
|
-
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
|
576
|
-
|
|
577
|
-
scale = module.head_dim**-0.5
|
|
578
|
-
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
|
|
579
|
-
|
|
580
|
-
if local_attention != (-1, -1):
|
|
581
|
-
attention_mask = sliding_window_mask
|
|
582
|
-
|
|
583
|
-
attn_weights = attn_weights + attention_mask
|
|
584
|
-
|
|
585
|
-
# upcast attention to fp32
|
|
586
|
-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
|
587
|
-
attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
|
|
588
|
-
attn_output = torch.matmul(attn_weights, value)
|
|
589
|
-
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
590
|
-
attn_output = attn_output.view(bs, -1, dim)
|
|
591
|
-
if output_attentions:
|
|
592
|
-
return (attn_output, attn_weights)
|
|
593
|
-
return (attn_output,)
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
def flash_attention_forward(
|
|
597
|
-
module: "ModernBertAttention",
|
|
598
|
-
qkv: torch.Tensor,
|
|
599
|
-
rotary_emb: ModernBertUnpaddedRotaryEmbedding,
|
|
600
|
-
cu_seqlens: torch.Tensor,
|
|
601
|
-
max_seqlen: int,
|
|
602
|
-
local_attention: tuple[int, int],
|
|
603
|
-
bs: int,
|
|
604
|
-
dim: int,
|
|
605
|
-
target_dtype: torch.dtype = torch.bfloat16,
|
|
606
|
-
**_kwargs,
|
|
607
|
-
) -> tuple[torch.Tensor]:
|
|
608
|
-
# (total_seqlen, 3, nheads, headdim)
|
|
609
|
-
qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
|
|
610
|
-
|
|
611
|
-
convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
|
|
612
|
-
if convert_dtype:
|
|
613
|
-
# FA2 implementation only supports fp16 and bf16. If FA2 is supported,
|
|
614
|
-
# bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
|
|
615
|
-
orig_dtype = qkv.dtype
|
|
616
|
-
qkv = qkv.to(target_dtype)
|
|
617
|
-
|
|
618
|
-
attn = flash_attn_varlen_qkvpacked_func(
|
|
619
|
-
qkv,
|
|
620
|
-
cu_seqlens=cu_seqlens,
|
|
621
|
-
max_seqlen=max_seqlen,
|
|
622
|
-
dropout_p=module.attention_dropout if module.training else 0.0,
|
|
623
|
-
deterministic=module.deterministic_flash_attn,
|
|
624
|
-
window_size=local_attention,
|
|
625
|
-
)
|
|
626
|
-
attn = attn.to(orig_dtype) # type: ignore
|
|
627
|
-
else:
|
|
628
|
-
attn = flash_attn_varlen_qkvpacked_func(
|
|
629
|
-
qkv,
|
|
630
|
-
cu_seqlens=cu_seqlens,
|
|
631
|
-
max_seqlen=max_seqlen,
|
|
632
|
-
dropout_p=module.attention_dropout if module.training else 0.0,
|
|
633
|
-
deterministic=module.deterministic_flash_attn,
|
|
634
|
-
window_size=local_attention,
|
|
635
|
-
)
|
|
636
|
-
return (attn.view(bs, dim),)
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
def sdpa_attention_forward(
|
|
640
|
-
module: "ModernBertAttention",
|
|
641
|
-
qkv: torch.Tensor,
|
|
642
|
-
attention_mask: torch.Tensor,
|
|
643
|
-
sliding_window_mask: torch.Tensor,
|
|
644
|
-
position_ids: Optional[torch.LongTensor],
|
|
645
|
-
local_attention: tuple[int, int],
|
|
646
|
-
bs: int,
|
|
647
|
-
dim: int,
|
|
648
|
-
position_embeddings: torch.Tensor,
|
|
649
|
-
**_kwargs,
|
|
650
|
-
) -> tuple[torch.Tensor]:
|
|
651
|
-
# qkv: [batch_size, seqlen, 3, nheads, headdim]
|
|
652
|
-
cos, sin = position_embeddings
|
|
653
|
-
query, key, value = qkv.transpose(3, 1).unbind(dim=2)
|
|
654
|
-
# query, key, value: [batch_size, heads, seq_len, head_dim]
|
|
655
|
-
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
|
656
|
-
|
|
657
|
-
if local_attention != (-1, -1):
|
|
658
|
-
attention_mask = sliding_window_mask
|
|
659
|
-
|
|
660
|
-
attn_output = (
|
|
661
|
-
F.scaled_dot_product_attention(
|
|
662
|
-
query,
|
|
663
|
-
key,
|
|
664
|
-
value,
|
|
665
|
-
dropout_p=module.attention_dropout if module.training else 0.0,
|
|
666
|
-
attn_mask=attention_mask,
|
|
667
|
-
)
|
|
668
|
-
.transpose(1, 2)
|
|
669
|
-
.contiguous()
|
|
670
|
-
)
|
|
671
|
-
attn_output = attn_output.view(bs, -1, dim)
|
|
672
|
-
return (attn_output,)
|
|
673
|
-
|
|
358
|
+
+@use_kernel_func_from_hub("rotary_pos_emb")
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
 
-
-
-
-
-
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    original_dtype = q.dtype
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
+    k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
+    return q_embed.to(original_dtype), k_embed.to(original_dtype)
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class ModernBertAttention(nn.Module):
     """Performs multi-headed self attention on a batch of unpadded sequences.
 
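The docstring of the new `apply_rotary_pos_emb` above describes how `unsqueeze_dim` makes `cos`/`sin` broadcast against `(batch, heads, seq_len, head_dim)` tensors. A self-contained sketch of the same computation (with `rotate_half` re-implemented locally for illustration, not imported from the library):

```python
import torch

def rotate_half(x):
    # GPT-NeoX style rotation: swap the two halves of the last dim with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq_len, head_dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# Build cos/sin of shape (batch, seq_len, head_dim), then unsqueeze the heads axis
# (unsqueeze_dim=1) so they broadcast against q and k.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]
cos = torch.cat((angles.cos(), angles.cos()), dim=-1)[None]
sin = torch.cat((angles.sin(), angles.sin()), dim=-1)[None]

q_rot = (q * cos.unsqueeze(1)) + (rotate_half(q) * sin.unsqueeze(1))
k_rot = (k * cos.unsqueeze(1)) + (rotate_half(k) * sin.unsqueeze(1))
print(q_rot.shape, k_rot.shape)  # both torch.Size([1, 2, 4, 8])
```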
@@ -689,10 +393,10 @@ class ModernBertAttention(nn.Module):
         See `forward` method for additional details.
     """
 
-    def __init__(self, config: ModernBertConfig,
+    def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
         super().__init__()
         self.config = config
-        self.
+        self.layer_idx = layer_idx
 
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
@@ -701,29 +405,19 @@ class ModernBertAttention(nn.Module):
 
         self.attention_dropout = config.attention_dropout
         self.deterministic_flash_attn = config.deterministic_flash_attn
-        self.num_heads = config.num_attention_heads
         self.head_dim = config.hidden_size // config.num_attention_heads
-        self.
-
-
+        self.Wqkv = nn.Linear(
+            config.hidden_size, 3 * self.head_dim * config.num_attention_heads, bias=config.attention_bias
+        )
 
-        if
-
-
+        if config.layer_types[layer_idx] == "sliding_attention":
+            # config.sliding_window = local_attention // 2 (half-window size, e.g. 64 for local_attention=128)
+            # +1 is needed because flash attention sets inclusive boundaries (see modeling_flash_attention_utils.py)
+            self.sliding_window = config.sliding_window + 1
         else:
-            self.
-            max_position_embeddings = config.max_position_embeddings
+            self.sliding_window = None
 
-
-            rope_parameters_dict = (
-                self.config.rope_parameters[layer_type] if layer_type is not None else self.config.rope_parameters
-            )
-            rope_theta = rope_parameters_dict["rope_theta"]
-            self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
-                dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
-            )
-        else:
-            self.rotary_emb = None
+        self.is_causal = False
 
         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
@@ -731,82 +425,75 @@ class ModernBertAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings:
-
-        **kwargs,
-    ) -> torch.Tensor:
+        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+        attention_mask: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        input_shape = hidden_states.shape[:-1]
+
         qkv = self.Wqkv(hidden_states)
+        qkv = qkv.view(*input_shape, 3, -1, self.head_dim)
+        query_states, key_states, value_states = qkv.unbind(dim=-3)
 
-
-
-
-
-
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)
+
+        attention_interface = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
-
+        attn_output, attn_weights = attention_interface(
             self,
-
-
-
-
-
-
-
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=self.attention_dropout if self.training else 0.0,
+            scaling=self.head_dim**-0.5,
+            sliding_window=self.sliding_window,
+            deterministic=self.deterministic_flash_attn,
            **kwargs,
         )
-        hidden_states = attn_outputs[0]
-        hidden_states = self.out_drop(self.Wo(hidden_states))
 
-
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.out_drop(self.Wo(attn_output))
+        return attn_output, attn_weights
 
 
 class ModernBertEncoderLayer(GradientCheckpointingLayer):
-    def __init__(self, config: ModernBertConfig,
+    def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
         super().__init__()
         self.config = config
-
+        self.layer_idx = layer_idx
+        if layer_idx == 0:
             self.attn_norm = nn.Identity()
         else:
             self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
-        self.attn = ModernBertAttention(config=config,
+        self.attn = ModernBertAttention(config=config, layer_idx=layer_idx)
         self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.mlp = ModernBertMLP(config)
-        self.attention_type = config.layer_types[
-
-    @torch.compile(dynamic=True)
-    def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        return self.mlp(self.mlp_norm(hidden_states))
+        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask:
-
-
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
+        attention_mask: torch.Tensor | None = None,
+        position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> torch.Tensor:
-
+        attn_output, _ = self.attn(
            self.attn_norm(hidden_states),
-            attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
-            position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
            position_embeddings=position_embeddings,
-
-
-        hidden_states = hidden_states + attn_outputs[0]
-        mlp_output = (
-            self.compiled_mlp(hidden_states)
-            if self.config.reference_compile
-            else self.mlp(self.mlp_norm(hidden_states))
+            attention_mask=attention_mask,
+            **kwargs,
         )
-        hidden_states = hidden_states +
-
-        return
+        hidden_states = hidden_states + attn_output
+        hidden_states = hidden_states + self.mlp(self.mlp_norm(hidden_states))
+        return hidden_states
 
 
 @auto_docstring
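The rewritten attention above projects once with `Wqkv` and splits the packed projection into query/key/value via a view and `unbind`. A minimal shape sketch of that step (toy sizes, not the actual module):

```python
import torch

batch, seq_len, hidden, heads = 2, 5, 32, 4
head_dim = hidden // heads
wqkv = torch.nn.Linear(hidden, 3 * hidden, bias=False)

hidden_states = torch.randn(batch, seq_len, hidden)
input_shape = hidden_states.shape[:-1]

# Same reshape as the new forward: one projection, then split the packed
# (3, heads, head_dim) block into separate q/k/v tensors and move heads forward.
qkv = wqkv(hidden_states).view(*input_shape, 3, -1, head_dim)
query, key, value = qkv.unbind(dim=-3)
query, key, value = (t.transpose(1, 2) for t in (query, key, value))
print(query.shape)  # torch.Size([2, 4, 5, 8]) == (batch, heads, seq_len, head_dim)
```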
@@ -817,7 +504,13 @@ class ModernBertPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
     _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_flex_attn =
+    _supports_flex_attn = True
+    _supports_attention_backend = True
+
+    _can_record_outputs = {
+        "hidden_states": ModernBertEncoderLayer,
+        "attentions": ModernBertAttention,
+    }
 
     @torch.no_grad()
     def _init_weights(self, module: nn.Module):
@@ -879,75 +572,24 @@ class ModernBertPreTrainedModel(PreTrainedModel):
                 curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
                 init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
                 init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
-        elif isinstance(module, ModernBertUnpaddedRotaryEmbedding):
-            inv_freq = module._compute_inv_freq()
-            init.copy_(module.inv_freq, inv_freq)
 
     def _check_and_adjust_attn_implementation(
-        self, attn_implementation:
+        self, attn_implementation: str | None, is_init_check: bool = False
     ) -> str:
         """
         Checks and dispatches to hhe requested attention implementation.
         """
-        # If the user didn't specify anything, try to use flash_attention_2
+        # If the user didn't specify anything, try to use flash_attention_2.
         # Otherwise we fall back to the default SDPA -> Eager from the super() method.
-        # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
-        # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
-
         try:
-
-
-
-                else attn_implementation
+            requested_attn_implementation = "flash_attention_2" if attn_implementation is None else attn_implementation
+            return super()._check_and_adjust_attn_implementation(
+                attn_implementation=requested_attn_implementation, is_init_check=is_init_check
            )
        except (ValueError, ImportError):
-
-
-
-            )
-
-    def _maybe_set_compile(self):
-        if self.config.reference_compile is False:
-            return
-
-        if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
-            if self.config.reference_compile:
-                logger.warning_once(
-                    "If `accelerate` split the model across devices, `torch.compile` will not work. "
-                    "Falling back to non-compiled mode."
-                )
-            self.config.reference_compile = False
-
-        if self.device.type == "mps":
-            if self.config.reference_compile:
-                logger.warning_once(
-                    "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
-                    "Falling back to non-compiled mode."
-                )
-            self.config.reference_compile = False
-
-        if self.device.type == "cpu":
-            if self.config.reference_compile:
-                logger.warning_once(
-                    "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
-                    "Falling back to non-compiled mode."
-                )
-            self.config.reference_compile = False
-
-        if self.config.reference_compile is None:
-            self.config.reference_compile = is_triton_available()
-
-    def resize_token_embeddings(self, *args, **kwargs):
-        model_embeds = super().resize_token_embeddings(*args, **kwargs)
-
-        if self.config.reference_compile in {True, None}:
-            if self.config.reference_compile:
-                logger.warning_once(
-                    "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
-                )
-            self.config.reference_compile = False
-
-        return model_embeds
+            return super()._check_and_adjust_attn_implementation(
+                attn_implementation=attn_implementation, is_init_check=is_init_check
+            )
 
 
 @auto_docstring
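Per the simplified `_check_and_adjust_attn_implementation` above, ModernBERT still prefers `flash_attention_2` when the user does not request a backend, and otherwise defers to the parent-class fallback. A hedged usage sketch (checkpoint name and kwarg are the usual public API, shown only for illustration):

```python
from transformers import AutoModel

# Explicitly request SDPA; leaving attn_implementation unset would let the model
# try flash_attention_2 first and fall back through the parent-class logic.
model = AutoModel.from_pretrained("answerdotai/ModernBERT-base", attn_implementation="sdpa")
print(model.config._attn_implementation)  # "sdpa"
```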
@@ -957,7 +599,7 @@ class ModernBertModel(ModernBertPreTrainedModel):
         self.config = config
         self.embeddings = ModernBertEmbeddings(config)
         self.layers = nn.ModuleList(
-            [ModernBertEncoderLayer(config,
+            [ModernBertEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.rotary_emb = ModernBertRotaryEmbedding(config=config)
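The next hunk rewrites `ModernBertModel.forward` to take the standard `input_ids`/`attention_mask`/`position_ids`/`inputs_embeds` arguments and return a `BaseModelOutput`. A minimal usage sketch against that signature (illustrative; weights download from the Hub):

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = AutoModel.from_pretrained("answerdotai/ModernBERT-base")

inputs = tokenizer(["Paris is the capital of [MASK]."], return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch, seq_len, hidden_size)
```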
@@ -970,175 +612,53 @@ class ModernBertModel(ModernBertPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embeddings.tok_embeddings = value
 
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
-        input_ids:
-        attention_mask:
-
-
-
-
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
-        batch_size: Optional[int] = None,
-        seq_len: Optional[int] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
-        r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
-        """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+        input_ids: torch.LongTensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
1013
625
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
1014
626
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
1015
627
|
|
|
1016
|
-
|
|
1017
|
-
all_self_attentions = () if output_attentions else None
|
|
1018
|
-
|
|
1019
|
-
self._maybe_set_compile()
|
|
1020
|
-
|
|
1021
|
-
if input_ids is not None:
|
|
1022
|
-
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
|
1023
|
-
|
|
1024
|
-
if batch_size is None and seq_len is None:
|
|
1025
|
-
if inputs_embeds is not None:
|
|
1026
|
-
batch_size, seq_len = inputs_embeds.shape[:2]
|
|
1027
|
-
else:
|
|
1028
|
-
batch_size, seq_len = input_ids.shape[:2]
|
|
628
|
+
seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
|
|
1029
629
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
1030
630
|
|
|
1031
|
-
if
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
repad = False
|
|
1035
|
-
if self.config._attn_implementation == "flash_attention_2":
|
|
1036
|
-
if indices is None and cu_seqlens is None and max_seqlen is None:
|
|
1037
|
-
repad = True
|
|
1038
|
-
if inputs_embeds is None:
|
|
1039
|
-
with torch.no_grad():
|
|
1040
|
-
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
|
|
1041
|
-
inputs=input_ids, attention_mask=attention_mask
|
|
1042
|
-
)
|
|
1043
|
-
else:
|
|
1044
|
-
inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
|
|
1045
|
-
inputs=inputs_embeds, attention_mask=attention_mask
|
|
1046
|
-
)
|
|
1047
|
-
if position_ids is None:
|
|
1048
|
-
position_ids = indices.unsqueeze(0)
|
|
1049
|
-
else:
|
|
1050
|
-
if position_ids is None:
|
|
1051
|
-
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
|
|
1052
|
-
|
|
1053
|
-
attention_mask, sliding_window_mask = self._update_attention_mask(
|
|
1054
|
-
attention_mask, output_attentions=output_attentions
|
|
1055
|
-
)
|
|
631
|
+
if position_ids is None:
|
|
632
|
+
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
|
|
1056
633
|
|
|
1057
634
|
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
|
|
635
|
+
|
|
636
|
+
if not isinstance(attention_mask_mapping := attention_mask, dict):
|
|
637
|
+
mask_kwargs = {
|
|
638
|
+
"config": self.config,
|
|
639
|
+
"input_embeds": hidden_states,
|
|
640
|
+
"attention_mask": attention_mask,
|
|
641
|
+
}
|
|
642
|
+
attention_mask_mapping = {
|
|
643
|
+
"full_attention": create_bidirectional_mask(**mask_kwargs),
|
|
644
|
+
"sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
|
|
645
|
+
}
|
|
646
|
+
|
|
1058
647
|
position_embeddings = {}
|
|
1059
648
|
for layer_type in self.config.layer_types:
|
|
1060
649
|
position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
|
|
1061
650
|
|
|
1062
651
|
for encoder_layer in self.layers:
|
|
1063
|
-
|
|
1064
|
-
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
1065
|
-
|
|
1066
|
-
layer_outputs = encoder_layer(
|
|
652
|
+
hidden_states = encoder_layer(
|
|
1067
653
|
hidden_states,
|
|
1068
|
-
attention_mask=
|
|
1069
|
-
sliding_window_mask=sliding_window_mask,
|
|
1070
|
-
position_ids=position_ids,
|
|
1071
|
-
cu_seqlens=cu_seqlens,
|
|
1072
|
-
max_seqlen=max_seqlen,
|
|
654
|
+
attention_mask=attention_mask_mapping[encoder_layer.attention_type],
|
|
1073
655
|
position_embeddings=position_embeddings[encoder_layer.attention_type],
|
|
1074
|
-
|
|
656
|
+
**kwargs,
|
|
1075
657
|
)
|
|
1076
|
-
hidden_states = layer_outputs[0]
|
|
1077
|
-
if output_attentions and len(layer_outputs) > 1:
|
|
1078
|
-
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
|
1079
|
-
|
|
1080
|
-
if output_hidden_states:
|
|
1081
|
-
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
1082
658
|
|
|
1083
659
|
hidden_states = self.final_norm(hidden_states)
|
|
1084
660
|
|
|
1085
|
-
|
|
1086
|
-
hidden_states = _pad_modernbert_output(
|
|
1087
|
-
inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
|
|
1088
|
-
)
|
|
1089
|
-
if all_hidden_states is not None:
|
|
1090
|
-
all_hidden_states = tuple(
|
|
1091
|
-
_pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
|
|
1092
|
-
for hs in all_hidden_states
|
|
1093
|
-
)
|
|
1094
|
-
# If the attention implementation is FA2 and there is no need for repadding, there might still be the batch
|
|
1095
|
-
# dimension missing
|
|
1096
|
-
elif (
|
|
1097
|
-
self.config._attn_implementation == "flash_attention_2"
|
|
1098
|
-
and all_hidden_states is not None
|
|
1099
|
-
and all_hidden_states[-1].dim() == 2
|
|
1100
|
-
):
|
|
1101
|
-
hidden_states = hidden_states.unsqueeze(0)
|
|
1102
|
-
all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states)
|
|
1103
|
-
|
|
1104
|
-
if not return_dict:
|
|
1105
|
-
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
|
1106
|
-
return BaseModelOutput(
|
|
1107
|
-
last_hidden_state=hidden_states,
|
|
1108
|
-
hidden_states=all_hidden_states,
|
|
1109
|
-
attentions=all_self_attentions,
|
|
1110
|
-
)
|
|
1111
|
-
|
|
1112
|
-
def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
|
|
1113
|
-
if output_attentions:
|
|
1114
|
-
if self.config._attn_implementation == "sdpa":
|
|
1115
|
-
logger.warning_once(
|
|
1116
|
-
"Outputting attentions is only supported with the 'eager' attention implementation, "
|
|
1117
|
-
'not with "sdpa". Falling back to `attn_implementation="eager"`.'
|
|
1118
|
-
)
|
|
1119
|
-
self.config._attn_implementation = "eager"
|
|
1120
|
-
elif self.config._attn_implementation != "eager":
|
|
1121
|
-
logger.warning_once(
|
|
1122
|
-
"Outputting attentions is only supported with the eager attention implementation, "
|
|
1123
|
-
f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
|
|
1124
|
-
" Setting `output_attentions=False`."
|
|
1125
|
-
)
|
|
1126
|
-
|
|
1127
|
-
global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
|
|
1128
|
-
|
|
1129
|
-
# Create position indices
|
|
1130
|
-
rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
|
|
1131
|
-
# Calculate distance between positions
|
|
1132
|
-
distance = torch.abs(rows - rows.T)
|
|
1133
|
-
|
|
1134
|
-
# Create sliding window mask (1 for positions within window, 0 outside)
|
|
1135
|
-
window_mask = (
|
|
1136
|
-
(distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
|
|
1137
|
-
)
|
|
1138
|
-
# Combine with existing mask
|
|
1139
|
-
sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
|
|
1140
|
-
|
|
1141
|
-
return global_attention_mask, sliding_window_mask
|
|
661
|
+
return BaseModelOutput(last_hidden_state=hidden_states)
|
|
1142
662
|
|
|
1143
663
|
|
|
1144
664
|
class ModernBertPredictionHead(nn.Module):
|
|
@@ -1180,84 +700,23 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
|
|
1180
700
|
def set_output_embeddings(self, new_embeddings: nn.Linear):
|
|
1181
701
|
self.decoder = new_embeddings
|
|
1182
702
|
|
|
1183
|
-
@
|
|
1184
|
-
def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
|
|
1185
|
-
return self.decoder(self.head(output))
|
|
1186
|
-
|
|
703
|
+
@can_return_tuple
|
|
1187
704
|
@auto_docstring
|
|
1188
705
|
def forward(
|
|
1189
706
|
self,
|
|
1190
|
-
input_ids:
|
|
1191
|
-
attention_mask:
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
cu_seqlens: Optional[torch.Tensor] = None,
|
|
1198
|
-
max_seqlen: Optional[int] = None,
|
|
1199
|
-
batch_size: Optional[int] = None,
|
|
1200
|
-
seq_len: Optional[int] = None,
|
|
1201
|
-
output_attentions: Optional[bool] = None,
|
|
1202
|
-
output_hidden_states: Optional[bool] = None,
|
|
1203
|
-
return_dict: Optional[bool] = None,
|
|
1204
|
-
**kwargs,
|
|
1205
|
-
) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
|
|
1206
|
-
r"""
|
|
1207
|
-
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
1208
|
-
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
|
1209
|
-
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
|
1210
|
-
far-away tokens in the local attention layers when not using Flash Attention.
|
|
1211
|
-
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
|
1212
|
-
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
|
1213
|
-
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
|
1214
|
-
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
|
1215
|
-
max_seqlen (`int`, *optional*):
|
|
1216
|
-
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
|
|
1217
|
-
batch_size (`int`, *optional*):
|
|
1218
|
-
Batch size of the input sequences. Used to pad the output tensors.
|
|
1219
|
-
seq_len (`int`, *optional*):
|
|
1220
|
-
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
|
|
1221
|
-
"""
|
|
1222
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1223
|
-
self._maybe_set_compile()
|
|
1224
|
-
|
|
1225
|
-
if self.config._attn_implementation == "flash_attention_2":
|
|
1226
|
-
if indices is None and cu_seqlens is None and max_seqlen is None:
|
|
1227
|
-
if batch_size is None and seq_len is None:
|
|
1228
|
-
if inputs_embeds is not None:
|
|
1229
|
-
batch_size, seq_len = inputs_embeds.shape[:2]
|
|
1230
|
-
else:
|
|
1231
|
-
batch_size, seq_len = input_ids.shape[:2]
|
|
1232
|
-
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
1233
|
-
|
|
1234
|
-
if attention_mask is None:
|
|
1235
|
-
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
|
|
1236
|
-
|
|
1237
|
-
if inputs_embeds is None:
|
|
1238
|
-
with torch.no_grad():
|
|
1239
|
-
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
|
|
1240
|
-
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
|
|
1241
|
-
)
|
|
1242
|
-
else:
|
|
1243
|
-
inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
|
|
1244
|
-
inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
|
|
1245
|
-
)
|
|
1246
|
-
|
|
707
|
+
input_ids: torch.LongTensor | None = None,
|
|
708
|
+
attention_mask: torch.Tensor | None = None,
|
|
709
|
+
position_ids: torch.Tensor | None = None,
|
|
710
|
+
inputs_embeds: torch.Tensor | None = None,
|
|
711
|
+
labels: torch.Tensor | None = None,
|
|
712
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
713
|
+
) -> tuple[torch.Tensor] | MaskedLMOutput:
|
|
1247
714
|
outputs = self.model(
|
|
1248
715
|
input_ids=input_ids,
|
|
1249
716
|
attention_mask=attention_mask,
|
|
1250
|
-
sliding_window_mask=sliding_window_mask,
|
|
1251
717
|
position_ids=position_ids,
|
|
1252
718
|
inputs_embeds=inputs_embeds,
|
|
1253
|
-
|
|
1254
|
-
cu_seqlens=cu_seqlens,
|
|
1255
|
-
max_seqlen=max_seqlen,
|
|
1256
|
-
batch_size=batch_size,
|
|
1257
|
-
seq_len=seq_len,
|
|
1258
|
-
output_attentions=output_attentions,
|
|
1259
|
-
output_hidden_states=output_hidden_states,
|
|
1260
|
-
return_dict=return_dict,
|
|
719
|
+
**kwargs,
|
|
1261
720
|
)
|
|
1262
721
|
last_hidden_state = outputs[0]
|
|
1263
722
|
|
|
@@ -1271,35 +730,12 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
|
|
1271
730
|
last_hidden_state = last_hidden_state[mask_tokens]
|
|
1272
731
|
labels = labels[mask_tokens]
|
|
1273
732
|
|
|
1274
|
-
logits = (
|
|
1275
|
-
self.compiled_head(last_hidden_state)
|
|
1276
|
-
if self.config.reference_compile
|
|
1277
|
-
else self.decoder(self.head(last_hidden_state))
|
|
1278
|
-
)
|
|
733
|
+
logits = self.decoder(self.head(last_hidden_state))
|
|
1279
734
|
|
|
1280
735
|
loss = None
|
|
1281
736
|
if labels is not None:
|
|
1282
737
|
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
|
|
1283
738
|
|
|
1284
|
-
if self.config._attn_implementation == "flash_attention_2":
|
|
1285
|
-
# Logits padding
|
|
1286
|
-
with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
|
|
1287
|
-
logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
|
|
1288
|
-
# Hidden states padding
|
|
1289
|
-
if getattr(outputs, "hidden_states", None) is not None:
|
|
1290
|
-
padded_hidden_states = []
|
|
1291
|
-
for hs in outputs.hidden_states:
|
|
1292
|
-
if hs.dim() == 3 and hs.shape[0] == 1:
|
|
1293
|
-
hs = hs.squeeze(0)
|
|
1294
|
-
padded_hidden_states.append(
|
|
1295
|
-
_pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
|
|
1296
|
-
)
|
|
1297
|
-
outputs.hidden_states = tuple(padded_hidden_states)
|
|
1298
|
-
|
|
1299
|
-
if not return_dict:
|
|
1300
|
-
output = (logits,)
|
|
1301
|
-
return ((loss,) + output) if loss is not None else output
|
|
1302
|
-
|
|
1303
739
|
return MaskedLMOutput(
|
|
1304
740
|
loss=loss,
|
|
1305
741
|
logits=logits,
|
|
@@ -1327,81 +763,39 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
|
|
|
1327
763
|
# Initialize weights and apply final processing
|
|
1328
764
|
self.post_init()
|
|
1329
765
|
|
|
766
|
+
@can_return_tuple
|
|
1330
767
|
@auto_docstring
|
|
1331
768
|
def forward(
|
|
1332
769
|
self,
|
|
1333
|
-
input_ids:
|
|
1334
|
-
attention_mask:
|
|
1335
|
-
|
|
1336
|
-
|
|
1337
|
-
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
cu_seqlens: Optional[torch.Tensor] = None,
|
|
1341
|
-
max_seqlen: Optional[int] = None,
|
|
1342
|
-
batch_size: Optional[int] = None,
|
|
1343
|
-
seq_len: Optional[int] = None,
|
|
1344
|
-
output_attentions: Optional[bool] = None,
|
|
1345
|
-
output_hidden_states: Optional[bool] = None,
|
|
1346
|
-
return_dict: Optional[bool] = None,
|
|
1347
|
-
**kwargs,
|
|
1348
|
-
) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
|
|
770
|
+
input_ids: torch.LongTensor | None = None,
|
|
771
|
+
attention_mask: torch.Tensor | None = None,
|
|
772
|
+
position_ids: torch.Tensor | None = None,
|
|
773
|
+
inputs_embeds: torch.Tensor | None = None,
|
|
774
|
+
labels: torch.Tensor | None = None,
|
|
775
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
776
|
+
) -> tuple[torch.Tensor] | SequenceClassifierOutput:
|
|
1349
777
|
r"""
|
|
1350
|
-
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
1351
|
-
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
|
1352
|
-
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
|
1353
|
-
far-away tokens in the local attention layers when not using Flash Attention.
|
|
1354
778
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
|
1355
779
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
|
1356
780
|
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
|
1357
781
|
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
|
1358
|
-
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
|
1359
|
-
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
|
1360
|
-
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
|
1361
|
-
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
|
1362
|
-
max_seqlen (`int`, *optional*):
|
|
1363
|
-
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
|
|
1364
|
-
batch_size (`int`, *optional*):
|
|
1365
|
-
Batch size of the input sequences. Used to pad the output tensors.
|
|
1366
|
-
seq_len (`int`, *optional*):
|
|
1367
|
-
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
|
|
1368
782
|
"""
|
|
1369
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1370
|
-
self._maybe_set_compile()
|
|
1371
|
-
|
|
1372
|
-
if input_ids is not None:
|
|
1373
|
-
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
|
1374
|
-
|
|
1375
|
-
if batch_size is None and seq_len is None:
|
|
1376
|
-
if inputs_embeds is not None:
|
|
1377
|
-
batch_size, seq_len = inputs_embeds.shape[:2]
|
|
1378
|
-
else:
|
|
1379
|
-
batch_size, seq_len = input_ids.shape[:2]
|
|
1380
|
-
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
1381
|
-
|
|
1382
|
-
if attention_mask is None:
|
|
1383
|
-
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
|
|
1384
|
-
|
|
1385
783
|
outputs = self.model(
|
|
1386
784
|
input_ids=input_ids,
|
|
1387
785
|
attention_mask=attention_mask,
|
|
1388
|
-
sliding_window_mask=sliding_window_mask,
|
|
1389
786
|
position_ids=position_ids,
|
|
1390
787
|
inputs_embeds=inputs_embeds,
|
|
1391
|
-
|
|
1392
|
-
cu_seqlens=cu_seqlens,
|
|
1393
|
-
max_seqlen=max_seqlen,
|
|
1394
|
-
batch_size=batch_size,
|
|
1395
|
-
seq_len=seq_len,
|
|
1396
|
-
output_attentions=output_attentions,
|
|
1397
|
-
output_hidden_states=output_hidden_states,
|
|
1398
|
-
return_dict=return_dict,
|
|
788
|
+
**kwargs,
|
|
1399
789
|
)
|
|
1400
790
|
last_hidden_state = outputs[0]
|
|
1401
791
|
|
|
1402
792
|
if self.config.classifier_pooling == "cls":
|
|
1403
793
|
last_hidden_state = last_hidden_state[:, 0]
|
|
1404
794
|
elif self.config.classifier_pooling == "mean":
|
|
795
|
+
if attention_mask is None:
|
|
796
|
+
attention_mask = torch.ones(
|
|
797
|
+
last_hidden_state.shape[:2], device=last_hidden_state.device, dtype=torch.bool
|
|
798
|
+
)
|
|
1405
799
|
last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
|
|
1406
800
|
dim=1, keepdim=True
|
|
1407
801
|
)
|
|
@@ -1433,10 +827,6 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
|
|
|
1433
827
|
loss_fct = BCEWithLogitsLoss()
|
|
1434
828
|
loss = loss_fct(logits, labels)
|
|
1435
829
|
|
|
1436
|
-
if not return_dict:
|
|
1437
|
-
output = (logits,)
|
|
1438
|
-
return ((loss,) + output) if loss is not None else output
|
|
1439
|
-
|
|
1440
830
|
return SequenceClassifierOutput(
|
|
1441
831
|
loss=loss,
|
|
1442
832
|
logits=logits,
|
|
@@ -1463,60 +853,27 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
|
|
1463
853
|
# Initialize weights and apply final processing
|
|
1464
854
|
self.post_init()
|
|
1465
855
|
|
|
856
|
+
@can_return_tuple
|
|
1466
857
|
@auto_docstring
|
|
1467
858
|
def forward(
|
|
1468
859
|
self,
|
|
1469
|
-
input_ids:
|
|
1470
|
-
attention_mask:
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
|
|
1474
|
-
|
|
1475
|
-
|
|
1476
|
-
cu_seqlens: Optional[torch.Tensor] = None,
|
|
1477
|
-
max_seqlen: Optional[int] = None,
|
|
1478
|
-
batch_size: Optional[int] = None,
|
|
1479
|
-
seq_len: Optional[int] = None,
|
|
1480
|
-
output_attentions: Optional[bool] = None,
|
|
1481
|
-
output_hidden_states: Optional[bool] = None,
|
|
1482
|
-
return_dict: Optional[bool] = None,
|
|
1483
|
-
**kwargs,
|
|
1484
|
-
) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
|
|
860
|
+
input_ids: torch.LongTensor | None = None,
|
|
861
|
+
attention_mask: torch.Tensor | None = None,
|
|
862
|
+
position_ids: torch.Tensor | None = None,
|
|
863
|
+
inputs_embeds: torch.Tensor | None = None,
|
|
864
|
+
labels: torch.Tensor | None = None,
|
|
865
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
866
|
+
) -> tuple[torch.Tensor] | TokenClassifierOutput:
|
|
1485
867
|
r"""
|
|
1486
|
-
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
1487
|
-
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
|
1488
|
-
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
|
1489
|
-
far-away tokens in the local attention layers when not using Flash Attention.
|
|
1490
868
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
1491
869
|
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
|
1492
|
-
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
|
1493
|
-
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
|
1494
|
-
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
|
1495
|
-
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
|
1496
|
-
max_seqlen (`int`, *optional*):
|
|
1497
|
-
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
|
|
1498
|
-
batch_size (`int`, *optional*):
|
|
1499
|
-
Batch size of the input sequences. Used to pad the output tensors.
|
|
1500
|
-
seq_len (`int`, *optional*):
|
|
1501
|
-
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
|
|
1502
870
|
"""
|
|
1503
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1504
|
-
self._maybe_set_compile()
|
|
1505
|
-
|
|
1506
871
|
outputs = self.model(
|
|
1507
872
|
input_ids=input_ids,
|
|
1508
873
|
attention_mask=attention_mask,
|
|
1509
|
-
sliding_window_mask=sliding_window_mask,
|
|
1510
874
|
position_ids=position_ids,
|
|
1511
875
|
inputs_embeds=inputs_embeds,
|
|
1512
|
-
|
|
1513
|
-
cu_seqlens=cu_seqlens,
|
|
1514
|
-
max_seqlen=max_seqlen,
|
|
1515
|
-
batch_size=batch_size,
|
|
1516
|
-
seq_len=seq_len,
|
|
1517
|
-
output_attentions=output_attentions,
|
|
1518
|
-
output_hidden_states=output_hidden_states,
|
|
1519
|
-
return_dict=return_dict,
|
|
876
|
+
**kwargs,
|
|
1520
877
|
)
|
|
1521
878
|
last_hidden_state = outputs[0]
|
|
1522
879
|
|
|
@@ -1529,10 +886,6 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
|
|
1529
886
|
loss_fct = CrossEntropyLoss()
|
|
1530
887
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
1531
888
|
|
|
1532
|
-
if not return_dict:
|
|
1533
|
-
output = (logits,) + outputs[1:]
|
|
1534
|
-
return ((loss,) + output) if loss is not None else output
|
|
1535
|
-
|
|
1536
889
|
return TokenClassifierOutput(
|
|
1537
890
|
loss=loss,
|
|
1538
891
|
logits=logits,
|
|
@@ -1554,57 +907,22 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
|
|
|
1554
907
|
|
|
1555
908
|
self.post_init()
|
|
1556
909
|
|
|
910
|
+
@can_return_tuple
|
|
1557
911
|
@auto_docstring
|
|
1558
912
|
def forward(
|
|
1559
913
|
self,
|
|
1560
|
-
input_ids:
|
|
1561
|
-
attention_mask:
|
|
1562
|
-
|
|
1563
|
-
|
|
1564
|
-
|
|
1565
|
-
|
|
1566
|
-
|
|
1567
|
-
cu_seqlens: Optional[torch.Tensor] = None,
|
|
1568
|
-
max_seqlen: Optional[int] = None,
|
|
1569
|
-
batch_size: Optional[int] = None,
|
|
1570
|
-
seq_len: Optional[int] = None,
|
|
1571
|
-
output_attentions: Optional[bool] = None,
|
|
1572
|
-
output_hidden_states: Optional[bool] = None,
|
|
1573
|
-
return_dict: Optional[bool] = None,
|
|
1574
|
-
**kwargs,
|
|
1575
|
-
) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
|
1576
|
-
r"""
|
|
1577
|
-
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
1578
|
-
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
|
1579
|
-
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
|
1580
|
-
far-away tokens in the local attention layers when not using Flash Attention.
|
|
1581
|
-
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
|
1582
|
-
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
|
1583
|
-
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
|
1584
|
-
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
|
1585
|
-
max_seqlen (`int`, *optional*):
|
|
1586
|
-
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
|
|
1587
|
-
batch_size (`int`, *optional*):
|
|
1588
|
-
Batch size of the input sequences. Used to pad the output tensors.
|
|
1589
|
-
seq_len (`int`, *optional*):
|
|
1590
|
-
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
|
|
1591
|
-
"""
|
|
1592
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1593
|
-
self._maybe_set_compile()
|
|
1594
|
-
|
|
914
|
+
input_ids: torch.Tensor | None = None,
|
|
915
|
+
attention_mask: torch.Tensor | None = None,
|
|
916
|
+
position_ids: torch.Tensor | None = None,
|
|
917
|
+
start_positions: torch.Tensor | None = None,
|
|
918
|
+
end_positions: torch.Tensor | None = None,
|
|
919
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
920
|
+
) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
|
|
1595
921
|
outputs = self.model(
|
|
1596
922
|
input_ids,
|
|
1597
923
|
attention_mask=attention_mask,
|
|
1598
|
-
sliding_window_mask=sliding_window_mask,
|
|
1599
924
|
position_ids=position_ids,
|
|
1600
|
-
|
|
1601
|
-
cu_seqlens=cu_seqlens,
|
|
1602
|
-
max_seqlen=max_seqlen,
|
|
1603
|
-
batch_size=batch_size,
|
|
1604
|
-
seq_len=seq_len,
|
|
1605
|
-
output_attentions=output_attentions,
|
|
1606
|
-
output_hidden_states=output_hidden_states,
|
|
1607
|
-
return_dict=return_dict,
|
|
925
|
+
**kwargs,
|
|
1608
926
|
)
|
|
1609
927
|
last_hidden_state = outputs[0]
|
|
1610
928
|
|
|
@@ -1620,10 +938,6 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
|
|
|
1620
938
|
if start_positions is not None and end_positions is not None:
|
|
1621
939
|
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
|
|
1622
940
|
|
|
1623
|
-
if not return_dict:
|
|
1624
|
-
output = (start_logits, end_logits) + outputs[1:]
|
|
1625
|
-
return ((loss,) + output) if loss is not None else output
|
|
1626
|
-
|
|
1627
941
|
return QuestionAnsweringModelOutput(
|
|
1628
942
|
loss=loss,
|
|
1629
943
|
start_logits=start_logits,
|
|
@@ -1651,45 +965,22 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
|
|
|
1651
965
|
# Initialize weights and apply final processing
|
|
1652
966
|
self.post_init()
|
|
1653
967
|
|
|
968
|
+
@can_return_tuple
|
|
1654
969
|
@auto_docstring
|
|
1655
970
|
def forward(
|
|
1656
971
|
self,
|
|
1657
|
-
input_ids:
|
|
1658
|
-
attention_mask:
|
|
1659
|
-
|
|
1660
|
-
|
|
1661
|
-
|
|
1662
|
-
|
|
1663
|
-
|
|
1664
|
-
cu_seqlens: Optional[torch.Tensor] = None,
|
|
1665
|
-
max_seqlen: Optional[int] = None,
|
|
1666
|
-
batch_size: Optional[int] = None,
|
|
1667
|
-
seq_len: Optional[int] = None,
|
|
1668
|
-
output_attentions: Optional[bool] = None,
|
|
1669
|
-
output_hidden_states: Optional[bool] = None,
|
|
1670
|
-
return_dict: Optional[bool] = None,
|
|
1671
|
-
**kwargs,
|
|
1672
|
-
) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
|
|
972
|
+
input_ids: torch.LongTensor | None = None,
|
|
973
|
+
attention_mask: torch.Tensor | None = None,
|
|
974
|
+
position_ids: torch.Tensor | None = None,
|
|
975
|
+
inputs_embeds: torch.Tensor | None = None,
|
|
976
|
+
labels: torch.Tensor | None = None,
|
|
977
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
978
|
+
) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
|
|
1673
979
|
r"""
|
|
1674
|
-
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
1675
|
-
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
|
1676
|
-
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
|
1677
|
-
far-away tokens in the local attention layers when not using Flash Attention.
|
|
1678
980
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
|
1679
981
|
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
|
1680
982
|
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
|
|
1681
|
-
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
|
1682
|
-
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
|
1683
|
-
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
|
1684
|
-
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
|
1685
|
-
max_seqlen (`int`, *optional*):
|
|
1686
|
-
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
|
|
1687
|
-
batch_size (`int`, *optional*):
|
|
1688
|
-
Batch size of the input sequences. Used to pad the output tensors.
|
|
1689
|
-
seq_len (`int`, *optional*):
|
|
1690
|
-
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
|
|
1691
983
|
"""
|
|
1692
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1693
984
|
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
|
1694
985
|
|
|
1695
986
|
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
|
@@ -1701,22 +992,12 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
|
|
|
1701
992
|
else None
|
|
1702
993
|
)
|
|
1703
994
|
|
|
1704
|
-
self._maybe_set_compile()
|
|
1705
|
-
|
|
1706
995
|
outputs = self.model(
|
|
1707
996
|
input_ids=input_ids,
|
|
1708
997
|
attention_mask=attention_mask,
|
|
1709
|
-
sliding_window_mask=sliding_window_mask,
|
|
1710
998
|
position_ids=position_ids,
|
|
1711
999
|
inputs_embeds=inputs_embeds,
|
|
1712
|
-
|
|
1713
|
-
cu_seqlens=cu_seqlens,
|
|
1714
|
-
max_seqlen=max_seqlen,
|
|
1715
|
-
batch_size=batch_size,
|
|
1716
|
-
seq_len=seq_len,
|
|
1717
|
-
output_attentions=output_attentions,
|
|
1718
|
-
output_hidden_states=output_hidden_states,
|
|
1719
|
-
return_dict=return_dict,
|
|
1000
|
+
**kwargs,
|
|
1720
1001
|
)
|
|
1721
1002
|
last_hidden_state = outputs[0] # shape (num_choices, seq_len, hidden_size)
|
|
1722
1003
|
|
|
@@ -1748,10 +1029,6 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
|
|
|
1748
1029
|
loss_fct = nn.CrossEntropyLoss()
|
|
1749
1030
|
loss = loss_fct(reshaped_logits, labels)
|
|
1750
1031
|
|
|
1751
|
-
if not return_dict:
|
|
1752
|
-
output = (reshaped_logits,) + outputs[1:]
|
|
1753
|
-
return ((loss,) + output) if loss is not None else output
|
|
1754
|
-
|
|
1755
1032
|
return MultipleChoiceModelOutput(
|
|
1756
1033
|
loss=loss,
|
|
1757
1034
|
logits=reshaped_logits,
|