transformers 5.0.0rc3__py3-none-any.whl → 5.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
|
@@ -55,7 +55,7 @@ class Scheduler(ABC):
|
|
|
55
55
|
self.waiting_requests_order.append(state.request_id)
|
|
56
56
|
|
|
57
57
|
@abstractmethod
|
|
58
|
-
def schedule_batch(self, token_budget: int, cache_budget: int) -> list[RequestState]:
|
|
58
|
+
def schedule_batch(self, token_budget: int, cache_budget: int) -> list[RequestState] | None:
|
|
59
59
|
"""Schedules requests for the next batch based on available token and cache budgets. This method selects which
|
|
60
60
|
requests should be processed in the current batch, considering the budgets and the scheduler's prioritization
|
|
61
61
|
rules. The token_budget is the maximum number of tokens that can be processed in a batch, and the cache_budget
|
|
@@ -64,7 +64,7 @@ class Scheduler(ABC):
|
|
|
64
64
|
@traced
|
|
65
65
|
def has_pending_requests(self) -> bool:
|
|
66
66
|
"""Checks if there are requests ready to be processed."""
|
|
67
|
-
return len(self.active_requests) or len(self.waiting_requests)
|
|
67
|
+
return bool(len(self.active_requests) or len(self.waiting_requests))
|
|
68
68
|
|
|
69
69
|
@traced
|
|
70
70
|
def finish_request(self, request_id: str, evict_from_cache: bool = True) -> None:
|
|
@@ -160,9 +160,11 @@ class Scheduler(ABC):
|
|
|
160
160
|
request_ids_to_remove_from_waiting: set[str],
|
|
161
161
|
) -> None:
|
|
162
162
|
"""Schedules a request for the current batch, updating the request's status according to the token budget left.
|
|
163
|
+
After a request is scheduled, it is part of the next batch unless there is an error.
|
|
163
164
|
If the request has children (for parallel decoding), it ensures at least one token remains before the request is
|
|
164
165
|
forked."""
|
|
165
166
|
# If the request has one or more children we make sure not to prefill it entirely
|
|
167
|
+
# This does not check the request state, but DECODING request already have children set to 0.
|
|
166
168
|
if state.num_children > 0 and token_budget >= len(request_tokens) - 1:
|
|
167
169
|
token_budget = len(request_tokens) - 1
|
|
168
170
|
self._requests_to_fork.append(state)
|
|
@@ -189,48 +191,27 @@ class Scheduler(ABC):
|
|
|
189
191
|
state.remaining_prefill_tokens = request_tokens[token_budget:]
|
|
190
192
|
state.tokens_to_process = request_tokens[:token_budget]
|
|
191
193
|
|
|
194
|
+
def _process_candidates(
|
|
195
|
+
self,
|
|
196
|
+
candidates: list[RequestState],
|
|
197
|
+
token_budget: int,
|
|
198
|
+
cache_budget: int,
|
|
199
|
+
request_ids_to_remove_from_waiting: set[str],
|
|
200
|
+
safety_margin: float = 0.0,
|
|
201
|
+
) -> tuple[list[RequestState], bool]:
|
|
202
|
+
"""Schedules candidate requests for the current batch.
|
|
192
203
|
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
"""This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
|
|
197
|
-
prefilling requests. Additionally, it includes a safety margin mechanism to prevent cache exhaustion. By default,
|
|
198
|
-
when 80% of the cache is full, new requests will not be scheduled to prioritize decoding active requests."""
|
|
199
|
-
|
|
200
|
-
def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.2):
|
|
201
|
-
"""Initializes the FIFO scheduler. The safety margin is the percentage of free blocks under which we stop
|
|
202
|
-
scheduling new prefill requests, so safety_margin = 0.1 means that when there is less than 10% of free blocks,
|
|
203
|
-
or equivalently when more than 90% of blocks are already allocated, we stop scheduling new prefill requests.
|
|
204
|
+
This method contains the common logic shared by all schedulers: it checks token and cache budgets, allocates
|
|
205
|
+
cache blocks if needed, updates request states, and tracks which waiting requests should be removed from the
|
|
206
|
+
waiting queue.
|
|
204
207
|
"""
|
|
205
|
-
super().__init__(cache, retain_cache_on_finish)
|
|
206
|
-
self.safety_margin = safety_margin
|
|
207
|
-
|
|
208
|
-
@traced
|
|
209
|
-
def schedule_batch(self, token_budget: int, cache_budget: int) -> list[RequestState] | None:
|
|
210
|
-
priority_states: list[RequestState] = []
|
|
211
|
-
second_priority_states: list[RequestState] = []
|
|
212
208
|
scheduled_requests = []
|
|
213
|
-
|
|
214
|
-
for state in self.active_requests.values():
|
|
215
|
-
if state.status == RequestStatus.DECODING:
|
|
216
|
-
priority_states.append(state)
|
|
217
|
-
if state.status in [RequestStatus.SPLIT_PENDING_REMAINDER, RequestStatus.PREFILLING_SPLIT]:
|
|
218
|
-
second_priority_states.append(state)
|
|
219
|
-
|
|
220
|
-
# Add waiting requests to second priority
|
|
221
|
-
if not self.block_new_requests:
|
|
222
|
-
for req_id in self.waiting_requests_order:
|
|
223
|
-
second_priority_states.append(self.waiting_requests[req_id])
|
|
224
|
-
|
|
225
|
-
candidates = priority_states + second_priority_states
|
|
226
|
-
request_ids_to_remove_from_waiting = set()
|
|
227
|
-
safety_margins = self.safety_margin * self.cache.num_blocks
|
|
228
|
-
|
|
229
209
|
one_allocation_failed = False
|
|
210
|
+
safety_margins = safety_margin * self.cache.num_blocks
|
|
230
211
|
|
|
231
212
|
for state in candidates:
|
|
232
|
-
# If we are out the safety margin, we only accept decoding requests or the first prefill request
|
|
233
213
|
num_free_blocks = self.cache.get_num_free_blocks()
|
|
214
|
+
# If we are out the safety margin, we only accept decoding requests or the first prefill request
|
|
234
215
|
outside_safety_margin = num_free_blocks < safety_margins
|
|
235
216
|
if outside_safety_margin and scheduled_requests and state.status != RequestStatus.DECODING:
|
|
236
217
|
logger.info(
|
|
@@ -256,8 +237,8 @@ class FIFOScheduler(Scheduler):
|
|
|
256
237
|
# If the allocation would not be successful, we move on to the next request
|
|
257
238
|
if not allocation_successful:
|
|
258
239
|
one_allocation_failed = True
|
|
259
|
-
# If we
|
|
260
|
-
# allocation as well
|
|
240
|
+
# If we reached a waiting request and the cache is full, all subsequent waiting requests will need
|
|
241
|
+
# allocation as well, so we can safely break out of the scheduling loop.
|
|
261
242
|
if num_free_blocks == 0 and state.request_id in self.waiting_requests:
|
|
262
243
|
logger.info(f"Breaking mid-loop for request {state.request_id} because the cache is full")
|
|
263
244
|
break
|
|
@@ -289,11 +270,59 @@ class FIFOScheduler(Scheduler):
|
|
|
289
270
|
if token_budget == 0 or cache_budget == 0:
|
|
290
271
|
break
|
|
291
272
|
|
|
292
|
-
|
|
273
|
+
return scheduled_requests, one_allocation_failed
|
|
274
|
+
|
|
275
|
+
def _cleanup_waiting_queue(self, request_ids_to_remove_from_waiting: set[str]) -> None:
|
|
276
|
+
"""Removes processed requests from the waiting queue order."""
|
|
293
277
|
self.waiting_requests_order = deque(
|
|
294
278
|
[req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
|
|
295
279
|
)
|
|
296
280
|
|
|
281
|
+
|
|
282
|
+
# TODO: further common-ize the two classes
|
|
283
|
+
@attach_tracer()
|
|
284
|
+
class FIFOScheduler(Scheduler):
|
|
285
|
+
"""This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
|
|
286
|
+
prefilling requests. Additionally, it includes a safety margin mechanism to prevent cache exhaustion. By default,
|
|
287
|
+
when 80% of the cache is full, new requests will not be scheduled to prioritize decoding active requests."""
|
|
288
|
+
|
|
289
|
+
def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.2):
|
|
290
|
+
"""Initializes the FIFO scheduler. The safety margin is the percentage of free blocks under which we stop
|
|
291
|
+
scheduling new prefill requests, so safety_margin = 0.1 means that when there is less than 10% of free blocks,
|
|
292
|
+
or equivalently when more than 90% of blocks are already allocated, we stop scheduling new prefill requests.
|
|
293
|
+
"""
|
|
294
|
+
super().__init__(cache, retain_cache_on_finish)
|
|
295
|
+
self.safety_margin = safety_margin
|
|
296
|
+
|
|
297
|
+
@traced
|
|
298
|
+
def schedule_batch(self, token_budget: int, cache_budget: int) -> list[RequestState] | None:
|
|
299
|
+
priority_states: list[RequestState] = []
|
|
300
|
+
second_priority_states: list[RequestState] = []
|
|
301
|
+
|
|
302
|
+
for state in self.active_requests.values():
|
|
303
|
+
if state.status == RequestStatus.DECODING:
|
|
304
|
+
priority_states.append(state)
|
|
305
|
+
if state.status in [RequestStatus.SPLIT_PENDING_REMAINDER, RequestStatus.PREFILLING_SPLIT]:
|
|
306
|
+
second_priority_states.append(state)
|
|
307
|
+
|
|
308
|
+
# Add waiting requests to second priority
|
|
309
|
+
if not self.block_new_requests:
|
|
310
|
+
for req_id in self.waiting_requests_order:
|
|
311
|
+
second_priority_states.append(self.waiting_requests[req_id])
|
|
312
|
+
|
|
313
|
+
candidates = priority_states + second_priority_states
|
|
314
|
+
request_ids_to_remove_from_waiting = set()
|
|
315
|
+
scheduled_requests, one_allocation_failed = self._process_candidates(
|
|
316
|
+
candidates,
|
|
317
|
+
token_budget,
|
|
318
|
+
cache_budget,
|
|
319
|
+
request_ids_to_remove_from_waiting,
|
|
320
|
+
safety_margin=self.safety_margin,
|
|
321
|
+
)
|
|
322
|
+
|
|
323
|
+
# We remove waiting requests before checking requests were scheduled, because there might have been prefill matches
|
|
324
|
+
self._cleanup_waiting_queue(request_ids_to_remove_from_waiting)
|
|
325
|
+
|
|
297
326
|
# If no requests were scheduled and the cache is full, we signal it by returning None
|
|
298
327
|
if not scheduled_requests and one_allocation_failed:
|
|
299
328
|
return None
|
|
@@ -313,7 +342,6 @@ class PrefillFirstScheduler(Scheduler):
|
|
|
313
342
|
def schedule_batch(self, token_budget: int, cache_budget: int) -> list[RequestState] | None:
|
|
314
343
|
priority_states: list[RequestState] = []
|
|
315
344
|
second_priority_states: list[RequestState] = []
|
|
316
|
-
scheduled_requests = []
|
|
317
345
|
|
|
318
346
|
for state in self.active_requests.values():
|
|
319
347
|
# XXX: when cache is full, state can stay on `PREFILLING_SPLIT` so we need to take those into account
|
|
@@ -329,62 +357,16 @@ class PrefillFirstScheduler(Scheduler):
|
|
|
329
357
|
|
|
330
358
|
candidates = priority_states + second_priority_states
|
|
331
359
|
request_ids_to_remove_from_waiting = set()
|
|
332
|
-
one_allocation_failed =
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
)
|
|
340
|
-
if cache_budget < cache_needed:
|
|
341
|
-
continue
|
|
342
|
-
|
|
343
|
-
# Infer the tokens that will be present in the batch if token budget is enough
|
|
344
|
-
request_tokens = self._infer_request_tokens(state, request_ids_to_remove_from_waiting)
|
|
345
|
-
# Account for token budget
|
|
346
|
-
request_len = min(len(request_tokens), token_budget)
|
|
347
|
-
# Check there will be enough cache for the new tokens
|
|
348
|
-
allocation_successful = self._allocate_blocks_if_needed(state, request_len)
|
|
349
|
-
|
|
350
|
-
# If the allocation would not be successful, we move on to the next request
|
|
351
|
-
if not allocation_successful:
|
|
352
|
-
one_allocation_failed = True
|
|
353
|
-
# If the request was waiting, all requests afterwards will need allocation, so we break if the cache is full
|
|
354
|
-
if state.request_id in self.waiting_requests and self.cache.get_num_free_blocks() == 0:
|
|
355
|
-
break
|
|
356
|
-
continue
|
|
357
|
-
|
|
358
|
-
# If this point is reached, it means we can safely schedule the request
|
|
359
|
-
self._schedule_request(state, request_tokens, token_budget, request_ids_to_remove_from_waiting)
|
|
360
|
-
request_len = len(state.tokens_to_process) # it may change after scheduling
|
|
361
|
-
scheduled_requests.append(state)
|
|
362
|
-
|
|
363
|
-
# Update the token and cache budgets
|
|
364
|
-
token_budget -= request_len
|
|
365
|
-
cache_budget -= cache_needed
|
|
366
|
-
|
|
367
|
-
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
|
|
368
|
-
if self.cache.allow_block_sharing:
|
|
369
|
-
tokens_in_current_block = state.current_len() % self.cache.block_size
|
|
370
|
-
tokens_after_forward = tokens_in_current_block + request_len
|
|
371
|
-
complete_blocks = tokens_after_forward // self.cache.block_size
|
|
372
|
-
self.cache.blocks_to_complete[state.request_id] = complete_blocks
|
|
373
|
-
|
|
374
|
-
# Remove the request from the waiting queue and mark it as removed
|
|
375
|
-
req_id = state.request_id
|
|
376
|
-
was_waiting = self.waiting_requests.pop(req_id, None) is not None
|
|
377
|
-
if was_waiting:
|
|
378
|
-
request_ids_to_remove_from_waiting.add(req_id)
|
|
379
|
-
|
|
380
|
-
# Early exit of the loop if we have no budget left
|
|
381
|
-
if token_budget == 0 or cache_budget == 0:
|
|
382
|
-
break
|
|
360
|
+
scheduled_requests, one_allocation_failed = self._process_candidates(
|
|
361
|
+
candidates,
|
|
362
|
+
token_budget,
|
|
363
|
+
cache_budget,
|
|
364
|
+
request_ids_to_remove_from_waiting,
|
|
365
|
+
safety_margin=0.0,
|
|
366
|
+
)
|
|
383
367
|
|
|
384
368
|
# We remove waiting requests before checking requests were scheduled, because there might have been prefill matches
|
|
385
|
-
self.
|
|
386
|
-
[req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
|
|
387
|
-
)
|
|
369
|
+
self._cleanup_waiting_queue(request_ids_to_remove_from_waiting)
|
|
388
370
|
|
|
389
371
|
# If no requests were scheduled and the cache is full, we signal it by returning None
|
|
390
372
|
if not scheduled_requests and one_allocation_failed:
|
|
@@ -20,7 +20,6 @@ from typing import TYPE_CHECKING
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
import torch
|
|
22
22
|
|
|
23
|
-
from ..pytorch_utils import isin_mps_friendly
|
|
24
23
|
from ..utils import add_start_docstrings
|
|
25
24
|
from ..utils.logging import get_logger
|
|
26
25
|
|
|
@@ -93,6 +92,12 @@ class LogitsProcessorList(list):
|
|
|
93
92
|
|
|
94
93
|
return scores
|
|
95
94
|
|
|
95
|
+
def set_continuous_batching_context(self, logits_indices: torch.Tensor, cu_seq_lens_q: torch.Tensor) -> None:
|
|
96
|
+
"""Forwards the continuous batching metadata to all logit processors that need it."""
|
|
97
|
+
for processor in self:
|
|
98
|
+
if hasattr(processor, "set_continuous_batching_context"):
|
|
99
|
+
processor.set_continuous_batching_context(logits_indices, cu_seq_lens_q)
|
|
100
|
+
|
|
96
101
|
|
|
97
102
|
class MinLengthLogitsProcessor(LogitsProcessor):
|
|
98
103
|
r"""
|
|
@@ -148,7 +153,7 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
|
|
148
153
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
|
149
154
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
|
150
155
|
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
|
151
|
-
eos_token_mask =
|
|
156
|
+
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
|
|
152
157
|
scores_processed = scores.clone()
|
|
153
158
|
if input_ids.shape[-1] < self.min_length:
|
|
154
159
|
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
|
@@ -220,7 +225,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
|
|
220
225
|
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
|
|
221
226
|
scores_processed = scores.clone()
|
|
222
227
|
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
|
223
|
-
eos_token_mask =
|
|
228
|
+
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
|
|
224
229
|
if new_tokens_length < self.min_new_tokens:
|
|
225
230
|
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
|
226
231
|
|
|
@@ -1847,7 +1852,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
|
|
|
1847
1852
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
|
1848
1853
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
|
1849
1854
|
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
|
1850
|
-
suppress_token_mask =
|
|
1855
|
+
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
|
|
1851
1856
|
scores_processed = scores
|
|
1852
1857
|
if input_ids.shape[-1] == self.begin_index:
|
|
1853
1858
|
scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
|
|
@@ -1890,7 +1895,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
|
|
|
1890
1895
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
|
1891
1896
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
|
1892
1897
|
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
|
1893
|
-
suppress_token_mask =
|
|
1898
|
+
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens.to(scores.device))
|
|
1894
1899
|
scores = torch.where(suppress_token_mask, -float("inf"), scores)
|
|
1895
1900
|
return scores
|
|
1896
1901
|
|
|
@@ -8,7 +8,6 @@ import numpy as np
|
|
|
8
8
|
import torch
|
|
9
9
|
from torch.nn import functional as F
|
|
10
10
|
|
|
11
|
-
from ..pytorch_utils import isin_mps_friendly
|
|
12
11
|
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
|
13
12
|
from ..utils import add_start_docstrings, logging
|
|
14
13
|
|
|
@@ -468,7 +467,7 @@ class EosTokenCriteria(StoppingCriteria):
|
|
|
468
467
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
|
469
468
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
|
470
469
|
self.eos_token_id = self.eos_token_id.to(input_ids.device)
|
|
471
|
-
is_done =
|
|
470
|
+
is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
|
|
472
471
|
return is_done
|
|
473
472
|
|
|
474
473
|
|
transformers/generation/utils.py
CHANGED
|
@@ -42,17 +42,15 @@ from ..dynamic_module_utils import (
|
|
|
42
42
|
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
|
|
43
43
|
from ..integrations.fsdp import is_fsdp_managed_module
|
|
44
44
|
from ..masking_utils import create_masks_for_generate
|
|
45
|
-
from ..pytorch_utils import isin_mps_friendly
|
|
46
45
|
from ..tokenization_python import ExtensionsTrie
|
|
47
46
|
from ..utils import (
|
|
48
47
|
ModelOutput,
|
|
49
48
|
TransformersKwargs,
|
|
50
49
|
is_accelerate_available,
|
|
51
|
-
is_hqq_available,
|
|
52
|
-
is_optimum_quanto_available,
|
|
53
50
|
is_torchdynamo_exporting,
|
|
54
51
|
logging,
|
|
55
52
|
)
|
|
53
|
+
from ..utils.generic import is_flash_attention_requested
|
|
56
54
|
from .candidate_generator import (
|
|
57
55
|
AssistantVocabTranslatorCache,
|
|
58
56
|
AssistedCandidateGenerator,
|
|
@@ -861,11 +859,9 @@ class GenerationMixin(ContinuousMixin):
|
|
|
861
859
|
if not is_input_ids:
|
|
862
860
|
return default_attention_mask
|
|
863
861
|
|
|
864
|
-
is_pad_token_in_inputs = (pad_token_id is not None) and (
|
|
865
|
-
isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
|
|
866
|
-
)
|
|
862
|
+
is_pad_token_in_inputs = (pad_token_id is not None) and (torch.isin(inputs_tensor, pad_token_id).any())
|
|
867
863
|
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
|
|
868
|
-
|
|
864
|
+
torch.isin(eos_token_id, pad_token_id).any()
|
|
869
865
|
)
|
|
870
866
|
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
|
|
871
867
|
attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()
|
|
@@ -1772,9 +1768,9 @@ class GenerationMixin(ContinuousMixin):
|
|
|
1772
1768
|
"""
|
|
1773
1769
|
# parameterization priority:
|
|
1774
1770
|
# user-defined kwargs or `generation_config` > `self.generation_config` > global default values
|
|
1775
|
-
# TODO: (raushan) doesn't make sense to allow kwargs and `generation_config`. Should be mutually exclusive!
|
|
1776
1771
|
# TODO (joao): per-model generation config classes.
|
|
1777
1772
|
|
|
1773
|
+
generation_config_provided = generation_config is not None
|
|
1778
1774
|
if generation_config is None:
|
|
1779
1775
|
# Users may modify `model.config` to control generation. This is a legacy behavior and is not supported anymore
|
|
1780
1776
|
if len(self.config._get_generation_parameters()) > 0:
|
|
@@ -1810,6 +1806,16 @@ class GenerationMixin(ContinuousMixin):
|
|
|
1810
1806
|
if generation_config.cache_implementation == "hybrid":
|
|
1811
1807
|
generation_config.cache_implementation = None
|
|
1812
1808
|
|
|
1809
|
+
# It doesn't make sense to allow kwargs and `generation_config`, that should be mutually exclusive
|
|
1810
|
+
if generation_config_provided and set(kwargs.keys()) - set(model_kwargs.keys()):
|
|
1811
|
+
generation_kwargs = set(kwargs.keys()) - set(model_kwargs.keys())
|
|
1812
|
+
logger.warning_once(
|
|
1813
|
+
f"Passing `generation_config` together with generation-related "
|
|
1814
|
+
f"arguments=({generation_kwargs}) is deprecated and will be removed in future versions. "
|
|
1815
|
+
"Please pass either a `generation_config` object OR all generation "
|
|
1816
|
+
"parameters explicitly, but not both.",
|
|
1817
|
+
)
|
|
1818
|
+
|
|
1813
1819
|
# Finally keep output_xxx args in `model_kwargs` so it can be passed to `forward`
|
|
1814
1820
|
output_attentions = generation_config.output_attentions
|
|
1815
1821
|
output_hidden_states = generation_config.output_hidden_states
|
|
@@ -1847,20 +1853,19 @@ class GenerationMixin(ContinuousMixin):
|
|
|
1847
1853
|
model_kwargs["cache_position"] = cache_position
|
|
1848
1854
|
return model_kwargs
|
|
1849
1855
|
|
|
1850
|
-
def
|
|
1856
|
+
def _prepare_static_cache(
|
|
1857
|
+
self, cache_implementation: str, batch_size: int, max_cache_len: int, model_kwargs
|
|
1858
|
+
) -> Cache:
|
|
1851
1859
|
"""
|
|
1852
1860
|
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
|
|
1853
1861
|
new `generate` call requires a larger cache or uses a different batch size.
|
|
1854
1862
|
|
|
1855
1863
|
Returns the resulting cache object.
|
|
1856
1864
|
"""
|
|
1857
|
-
requires_cross_attention_cache = (
|
|
1858
|
-
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
|
1859
|
-
)
|
|
1860
1865
|
offload_cache = "offloaded" in cache_implementation
|
|
1861
1866
|
|
|
1862
1867
|
if hasattr(self, "_cache"):
|
|
1863
|
-
cache_to_check = self._cache.self_attention_cache if
|
|
1868
|
+
cache_to_check = self._cache.self_attention_cache if self.config.is_encoder_decoder else self._cache
|
|
1864
1869
|
|
|
1865
1870
|
need_new_cache = (
|
|
1866
1871
|
not hasattr(self, "_cache")
|
|
@@ -1869,7 +1874,7 @@ class GenerationMixin(ContinuousMixin):
|
|
|
1869
1874
|
or cache_to_check.max_cache_len < max_cache_len
|
|
1870
1875
|
)
|
|
1871
1876
|
|
|
1872
|
-
if
|
|
1877
|
+
if self.config.is_encoder_decoder and hasattr(self, "_cache"):
|
|
1873
1878
|
need_new_cache = (
|
|
1874
1879
|
need_new_cache
|
|
1875
1880
|
or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
|
|
@@ -1882,7 +1887,7 @@ class GenerationMixin(ContinuousMixin):
|
|
|
1882
1887
|
"offloading": offload_cache,
|
|
1883
1888
|
}
|
|
1884
1889
|
self._cache = StaticCache(**self_attention_cache_kwargs)
|
|
1885
|
-
if
|
|
1890
|
+
if self.config.is_encoder_decoder:
|
|
1886
1891
|
cross_attention_cache_kwargs = {
|
|
1887
1892
|
"config": self.config.get_text_config(decoder=True),
|
|
1888
1893
|
"max_cache_len": model_kwargs["encoder_outputs"][0].shape[1],
|
|
@@ -1925,12 +1930,9 @@ class GenerationMixin(ContinuousMixin):
|
|
|
1925
1930
|
instantiated, writes it to `model_kwargs`, under the name expected by the model.
|
|
1926
1931
|
"""
|
|
1927
1932
|
|
|
1928
|
-
|
|
1929
|
-
|
|
1930
|
-
|
|
1931
|
-
requires_cross_attention_cache = (
|
|
1932
|
-
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
|
1933
|
-
)
|
|
1933
|
+
# TODO @raushan, unify cache arg naming for all models
|
|
1934
|
+
is_linear_attn_cache = "mamba" in self.__class__.__name__.lower()
|
|
1935
|
+
cache_name = "past_key_values" if not is_linear_attn_cache else "cache_params"
|
|
1934
1936
|
|
|
1935
1937
|
# Quick escape route 1: if the user specifies a cache, we only need to check for conflicting `generate` arguments
|
|
1936
1938
|
user_defined_cache = model_kwargs.get(cache_name)
|
|
@@ -1962,76 +1964,55 @@ class GenerationMixin(ContinuousMixin):
|
|
|
1962
1964
|
|
|
1963
1965
|
# Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`
|
|
1964
1966
|
|
|
1965
|
-
# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
|
|
1966
|
-
# which is only supported in dynamic caches atm
|
|
1967
|
-
if (
|
|
1968
|
-
generation_mode == GenerationMode.ASSISTED_GENERATION
|
|
1969
|
-
and generation_config.cache_implementation is not None
|
|
1970
|
-
):
|
|
1971
|
-
logger.warning_once(
|
|
1972
|
-
"An assistant model is provided, using a dynamic cache instead of a cache of type="
|
|
1973
|
-
f"'{generation_config.cache_implementation}'."
|
|
1974
|
-
)
|
|
1975
|
-
generation_config.cache_implementation = None
|
|
1976
|
-
|
|
1977
1967
|
# Assisted decoding and contrastive search require cache rollback, which is incompatible with sliding layers.
|
|
1978
1968
|
# To handle this, we skip passing the model config to DynamicCache (forcing a full-layer cache).
|
|
1979
1969
|
# The "dynamic_full" option is a shortcut for generate() users to avoid sliding layers on their own.
|
|
1980
|
-
if (
|
|
1981
|
-
|
|
1982
|
-
|
|
1983
|
-
|
|
1984
|
-
|
|
1985
|
-
else:
|
|
1986
|
-
dynamic_cache_kwargs = {"config": self.config.get_text_config(decoder=True)}
|
|
1987
|
-
if generation_config.cache_implementation is not None:
|
|
1988
|
-
if generation_config.cache_implementation in ALL_STATIC_CACHE_IMPLEMENTATIONS:
|
|
1989
|
-
if generation_config.cache_implementation in DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS:
|
|
1990
|
-
logger.warning_once(
|
|
1991
|
-
f"Using `cache_implementation='{generation_config.cache_implementation}' is deprecated. "
|
|
1992
|
-
f"Please only use one of {STATIC_CACHE_IMPLEMENTATIONS}, and the layer structure will be "
|
|
1993
|
-
"inferred automatically."
|
|
1994
|
-
)
|
|
1995
|
-
model_kwargs[cache_name] = self._get_cache(
|
|
1996
|
-
cache_implementation=generation_config.cache_implementation,
|
|
1997
|
-
batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
|
|
1998
|
-
max_cache_len=max_cache_length,
|
|
1999
|
-
model_kwargs=model_kwargs,
|
|
1970
|
+
if generation_mode in (GenerationMode.ASSISTED_GENERATION, GenerationMode.CONTRASTIVE_SEARCH):
|
|
1971
|
+
if generation_config.cache_implementation is not None:
|
|
1972
|
+
logger.warning_once(
|
|
1973
|
+
"An assistant model is provided, using a dynamic cache instead of a cache of type="
|
|
1974
|
+
f"'{generation_config.cache_implementation}'."
|
|
2000
1975
|
)
|
|
2001
|
-
|
|
2002
|
-
|
|
2003
|
-
|
|
2004
|
-
|
|
2005
|
-
|
|
2006
|
-
|
|
1976
|
+
generation_config.cache_implementation = "dynamic_full"
|
|
1977
|
+
|
|
1978
|
+
dynamic_cache_kwargs = {}
|
|
1979
|
+
if generation_config.cache_implementation != "dynamic_full":
|
|
1980
|
+
dynamic_cache_kwargs["config"] = self.config.get_text_config(decoder=True)
|
|
1981
|
+
|
|
1982
|
+
if generation_config.cache_implementation == "offloaded":
|
|
1983
|
+
dynamic_cache_kwargs["offloading"] = True
|
|
1984
|
+
|
|
1985
|
+
if generation_config.cache_implementation in ALL_STATIC_CACHE_IMPLEMENTATIONS:
|
|
1986
|
+
if generation_config.cache_implementation in DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS:
|
|
1987
|
+
logger.warning_once(
|
|
1988
|
+
f"Using `cache_implementation='{generation_config.cache_implementation}' is deprecated "
|
|
1989
|
+
f"and will be removed in v5.13. Please only use one of {STATIC_CACHE_IMPLEMENTATIONS}, "
|
|
1990
|
+
"and the layer structure will be inferred automatically."
|
|
1991
|
+
)
|
|
1992
|
+
model_kwargs["past_key_values"] = self._prepare_static_cache(
|
|
1993
|
+
cache_implementation=generation_config.cache_implementation,
|
|
1994
|
+
batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
|
|
1995
|
+
max_cache_len=max_cache_length,
|
|
1996
|
+
model_kwargs=model_kwargs,
|
|
1997
|
+
)
|
|
1998
|
+
elif generation_config.cache_implementation == "quantized":
|
|
1999
|
+
if self.config.is_encoder_decoder or not self._supports_default_dynamic_cache():
|
|
2000
|
+
raise ValueError(
|
|
2001
|
+
"This model does not support the quantized cache. If you want your model to support quantized "
|
|
2002
|
+
"cache, please open an issue and tag @zucchini-nlp."
|
|
2003
|
+
)
|
|
2004
|
+
|
|
2005
|
+
cache_config = generation_config.cache_config if generation_config.cache_config is not None else {}
|
|
2006
|
+
cache_config.setdefault("config", self.config.get_text_config(decoder=True))
|
|
2007
|
+
backend = cache_config.pop("backend", "quanto")
|
|
2008
|
+
model_kwargs["past_key_values"] = QuantizedCache(backend=backend, **cache_config)
|
|
2009
|
+
# i.e. `cache_implementation` in [None, "dynamic", "offloaded", "dynamic_full"]
|
|
2010
|
+
# TODO: prepare linear cache from a single API, instead of creating in modeling code
|
|
2011
|
+
else:
|
|
2012
|
+
model_kwargs["past_key_values"] = DynamicCache(**dynamic_cache_kwargs)
|
|
2007
2013
|
|
|
2008
|
-
cache_config = generation_config.cache_config if generation_config.cache_config is not None else {}
|
|
2009
|
-
# Add the config if it was not provided, as it's a required argument
|
|
2010
|
-
if "config" not in cache_config:
|
|
2011
|
-
cache_config["config"] = self.config.get_text_config()
|
|
2012
|
-
# Pop the backend from the config (defaults to quanto if not defined)
|
|
2013
|
-
backend = cache_config.pop("backend", "quanto")
|
|
2014
|
-
|
|
2015
|
-
if backend == "quanto" and not is_optimum_quanto_available():
|
|
2016
|
-
raise ImportError(
|
|
2017
|
-
"You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto "
|
|
2018
|
-
"backend. Please install it via with `pip install optimum-quanto`"
|
|
2019
|
-
)
|
|
2020
|
-
elif backend == "HQQ" and not is_hqq_available():
|
|
2021
|
-
raise ImportError(
|
|
2022
|
-
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
|
|
2023
|
-
"Please install it via with `pip install hqq`"
|
|
2024
|
-
)
|
|
2025
|
-
model_kwargs[cache_name] = QuantizedCache(backend=backend, **cache_config)
|
|
2026
|
-
elif generation_config.cache_implementation == "offloaded":
|
|
2027
|
-
model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs, offloading=True)
|
|
2028
|
-
elif "dynamic" in generation_config.cache_implementation:
|
|
2029
|
-
model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)
|
|
2030
|
-
|
|
2031
|
-
# TODO (joao): this logic is incomplete, e.g. `offloaded` should apply to both caches. Refactor this function
|
|
2032
|
-
# to correctly pass parameterization to both caches.
|
|
2033
2014
|
if (
|
|
2034
|
-
|
|
2015
|
+
self.config.is_encoder_decoder
|
|
2035
2016
|
and "past_key_values" in model_kwargs
|
|
2036
2017
|
and not isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
|
|
2037
2018
|
):
|
|
@@ -2102,10 +2083,7 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2102
2083
|
raise ValueError(
|
|
2103
2084
|
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
|
|
2104
2085
|
)
|
|
2105
|
-
if (
|
|
2106
|
-
eos_token_tensor is not None
|
|
2107
|
-
and isin_mps_friendly(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
|
|
2108
|
-
):
|
|
2086
|
+
if eos_token_tensor is not None and torch.isin(eos_token_tensor, pad_token_tensor).any():
|
|
2109
2087
|
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
|
|
2110
2088
|
logger.warning_once(
|
|
2111
2089
|
"The attention mask is not set and cannot be inferred from input because pad token is same as "
|
|
@@ -2172,13 +2150,13 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2172
2150
|
# Finally: if we can compile, disable tokenizers parallelism
|
|
2173
2151
|
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
|
2174
2152
|
|
|
2175
|
-
# If we use
|
|
2176
|
-
if self.config
|
|
2153
|
+
# If we use FA and a static cache, we cannot compile with fullgraph
|
|
2154
|
+
if is_flash_attention_requested(self.config):
|
|
2177
2155
|
# only raise warning if the user passed an explicit compile-config
|
|
2178
2156
|
if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
|
|
2179
2157
|
logger.warning_once(
|
|
2180
|
-
"When using Flash Attention
|
|
2181
|
-
"
|
|
2158
|
+
"When using Flash Attention and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
|
|
2159
|
+
"FA introduces graph breaks. We overrode the option with `fullgraph=False`."
|
|
2182
2160
|
)
|
|
2183
2161
|
generation_config.compile_config.fullgraph = False
|
|
2184
2162
|
|
|
@@ -2187,7 +2165,9 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2187
2165
|
@contextmanager
|
|
2188
2166
|
def _optimize_model_for_decode(self):
|
|
2189
2167
|
original_experts_implementation = self.config._experts_implementation
|
|
2190
|
-
|
|
2168
|
+
# On non-CPU devices, 'batched_mm' can trade off a bit of memory (by duplicating selected experts weights)
|
|
2169
|
+
# for much better speed during decoding, especially for smaller inputs. On CPU, grouped_mm is usually better.
|
|
2170
|
+
if original_experts_implementation == "grouped_mm" and self.device.type != "cpu":
|
|
2191
2171
|
logger.info_once(
|
|
2192
2172
|
"We will be switching to 'batched_mm' for the decoding stage as it is much more performant than 'grouped_mm' on smaller inputs. "
|
|
2193
2173
|
"If you experience any issues with this, please open an issue on the Hugging Face Transformers GitHub repository.",
|
|
@@ -2197,7 +2177,7 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2197
2177
|
try:
|
|
2198
2178
|
yield
|
|
2199
2179
|
finally:
|
|
2200
|
-
if original_experts_implementation == "grouped_mm":
|
|
2180
|
+
if original_experts_implementation == "grouped_mm" and self.device.type != "cpu":
|
|
2201
2181
|
self.set_experts_implementation(original_experts_implementation)
|
|
2202
2182
|
|
|
2203
2183
|
def _get_deprecated_gen_repo(
|