transformers 5.0.0rc3__py3-none-any.whl → 5.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
# Copyright 2026 The HuggingFace Inc. team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from functools import partial
|
|
16
|
+
from itertools import count
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
|
|
21
|
+
from transformers.configuration_utils import PretrainedConfig
|
|
22
|
+
|
|
23
|
+
from ...utils.metrics import traced
|
|
24
|
+
from .cache import PagedAttentionCache
|
|
25
|
+
from .requests import TMP_TOKEN_ID, RequestState
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def attn_mask_is_needed(config: PretrainedConfig) -> bool:
    """Return True when the configured attention backend needs an explicitly materialized attention mask.

    Only the eager and SDPA paged backends consume a dense mask tensor; other implementations
    (e.g. flash-attention variants) work from cumulative sequence lengths alone.
    """
    mask_requiring_backends = ("paged|eager", "paged|sdpa")
    return config._attn_implementation in mask_requiring_backends
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def build_attention_mask(
    attention_mask: torch.Tensor,
    cumulative_seqlens_q: list[int],
    cumulative_seqlens_k: list[int],
    sliding_window: int = 1,
) -> None:
    """Fill `attention_mask` in place with a block-diagonal attention bias.

    The mask is not boolean: allowed positions hold 0 and disallowed positions hold the dtype's
    minimum value (an additive score bias). Each consecutive pair in `cumulative_seqlens_q` /
    `cumulative_seqlens_k` delimits one (query, key) block on the diagonal; inside a block a
    causal mask is applied, and when `sliding_window > 1` a sliding-window mask is added on top.

    For a block with seqlen_k = 5, seqlen_q = 3 and sliding_window = 2 the result is::

        causal            sliding           combined
        █ █ █ ░ ░         ░ █ █ █ █         ░ █ █ ░ ░
        █ █ █ █ ░    +    ░ ░ █ █ █    =    ░ ░ █ █ ░
        █ █ █ █ █         ░ ░ ░ █ █         ░ ░ ░ █ █

    (█ = attended, ░ = masked; the sliding mask's tril diagonal offset is
    seqlen_k - seqlen_q - sliding_window.)
    """
    neg_bias = torch.finfo(attention_mask.dtype).min
    num_blocks = len(cumulative_seqlens_q) - 1
    for block in range(num_blocks):
        q_start, q_end = cumulative_seqlens_q[block], cumulative_seqlens_q[block + 1]
        k_start, k_end = cumulative_seqlens_k[block], cumulative_seqlens_k[block + 1]
        seqlen_q = q_end - q_start
        seqlen_k = k_end - k_start
        # When the query is a strict (non-empty) suffix of the keys, shift the causal diagonal so
        # the last query row attends to every key; otherwise use the standard causal diagonal.
        if 1 <= seqlen_q < seqlen_k:
            causal_offset = seqlen_k - seqlen_q + 1
        else:
            causal_offset = 1
        # Start from a block of -inf and keep only the upper-triangular (future) part as masked.
        region = attention_mask[..., q_start:q_end, k_start:k_end]
        full_bias = torch.full_like(region, neg_bias)
        block_mask = torch.triu(full_bias, diagonal=causal_offset)
        # Stack the sliding-window mask on top when the model limits the attention span.
        if sliding_window > 1:
            block_mask = block_mask + torch.tril(full_bias, diagonal=seqlen_k - seqlen_q - sliding_window)
        attention_mask[..., q_start:q_end, k_start:k_end] = block_mask
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@dataclass
class PagedAttentionArgs:
    """Keyword arguments for one forward pass under paged attention.

    Attributes:
        input_ids: Input token IDs, shape `(1, total_query_tokens)`.
        attention_mask: Attention bias tensor, or a dict keyed by layer type for hybrid models.
            `None` when the attention implementation needs no explicit mask.
        position_ids: Position IDs, shape `(1, total_query_tokens)`.
        cu_seq_lens_q: Cumulative query sequence lengths (variable-length batching).
        cu_seq_lens_k: Cumulative key/value sequence lengths; a tensor, or a dict keyed by layer
            type (e.g. "full_attention", "sliding_attention") for hybrid models.
        max_seqlen_q: Largest query sequence length in the batch.
        max_seqlen_k: Largest key/value sequence length; an int or a per-layer-type dict.
        write_index: One tensor per attention group telling the cache where new KV states go.
        read_index: One tensor per attention group telling the cache which positions to read.
        logits_indices: Positions in the output used for next-token prediction.
        cache: The [`PagedAttentionCache`] holding the KV states.
        use_cache: Always `False` here: the cache is managed externally by continuous batching.
    """

    input_ids: torch.Tensor
    attention_mask: torch.Tensor | dict[str, torch.Tensor] | None
    position_ids: torch.Tensor
    cu_seq_lens_q: torch.Tensor
    cu_seq_lens_k: torch.Tensor | dict[str, torch.Tensor]
    max_seqlen_q: int
    max_seqlen_k: int | dict[str, int]
    write_index: list[torch.Tensor]
    read_index: list[torch.Tensor]
    logits_indices: torch.Tensor
    cache: PagedAttentionCache
    use_cache: bool = False

    def asdict(self) -> dict[str, Any]:
        """Return a shallow kwargs dict (same tensor/cache references, no copies)."""
        field_names = (
            "input_ids",
            "attention_mask",
            "position_ids",
            "cu_seq_lens_q",
            "cu_seq_lens_k",
            "max_seqlen_q",
            "max_seqlen_k",
            "write_index",
            "read_index",
            "logits_indices",
            "cache",
            "use_cache",
        )
        return {name: getattr(self, name) for name in field_names}
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class ContinuousBatchingIOs:
    """Manages input/output tensors for continuous batching generation. This class handles the allocation and management
    of static tensors used during generation steps in continuous batching mode. Allocation is done once at init time.

    The class is responsible for:
    - Setting up static tensor storage for all generation inputs/outputs
    - Preparing batch tensors from a list of request states before each forward pass
    - Building model keyword arguments with optional padding for CUDA graphs/torch.compile
    - Resetting tensors between batches while minimizing memory operations

    It keeps track of the requests in the current batch as well as the actual number of tokens (Q and KV), sequences in
    the batch and sizes of indices. This is useful when using padded inputs, for CUDA graphs and/or torch.compile.
    """

    def __init__(
        self, cache: PagedAttentionCache, config: PretrainedConfig, device: torch.device, model_dtype: torch.dtype
    ) -> None:
        """Initialize the continuous batching I/O manager.

        Args:
            cache: The [`PagedAttentionCache`] instance managing the KV cache.
            config: The model's pretrained configuration.
            device: The device to allocate tensors on.
            model_dtype: The data type for model computations.
        """
        # Memoize attributes
        self.cache = cache
        self.device = device
        self.config = config
        self.model_dtype = model_dtype
        # A window of 1 is used as the "no sliding window" marker throughout this class
        self.sliding_window = 1 if getattr(config, "sliding_window", None) is None else config.sliding_window
        # Setup accumulators tracking how much of the static storage the current batch actually uses
        self.requests_in_batch: list[RequestState] = []
        self.actual_query_length = 0
        self.actual_key_length = 0
        self.actual_batch_size = 0
        # One (read_size, write_size) pair per attention group
        self.actual_index_sizes = [(0, 0) for _ in range(cache.num_groups)]
        # Setup static tensors, then put them in a known clean state
        self.setup_static_tensors()
        self.reset_static_tensors(full_reset=True)

    @traced(standalone=True)
    def setup_static_tensors(self) -> None:
        """Allocates static tensors for generation inputs and outputs. This is called only once at init time, to avoid
        repeated allocations and enable CUDA graphs. All tensors are allocated with maximum possible sizes.
        The allocated tensors are:

        - `input_ids` and `position_ids`: Query token information
        - `cumulative_seqlens_q` and `cumulative_seqlens_k`: Sequence length tracking for FlashAttention-style batching
        - `attention_mask`: Optional attention masks (only for eager/SDPA implementations)
        - `write_index` and `read_index` storage: Cache indexing tensors for each attention group
        - `output_ids`: Storage for generated token IDs
        """
        # Total number of KV slots in the paged cache
        num_pages = self.cache.num_blocks * self.cache.block_size

        # Some tensors always have the same shape regardless of the model
        self.input_ids = torch.empty((1, self.cache.max_batch_tokens), dtype=torch.int32, device=self.device)
        self.position_ids = torch.empty((1, self.cache.max_batch_tokens), dtype=torch.int32, device=self.device)
        self.cumulative_seqlens_q = torch.empty(
            (self.cache.max_batch_tokens + 1,), dtype=torch.int32, device=self.device
        )
        self.max_seqlen_q = 0
        self.logits_indices = torch.empty((self.cache.max_batch_tokens,), dtype=torch.int32, device=self.device)
        self.output_ids = torch.empty((self.cache.max_batch_tokens,), dtype=torch.int32, device=self.device)

        # For some kwargs, we have a dict of tensors with as many items as there are attention types
        self.cumulative_seqlens_k: dict[str, torch.Tensor] = {}
        if self.cache.num_full_attention_groups:
            self.cumulative_seqlens_k["full_attention"] = torch.empty(
                (self.cache.max_batch_tokens + 1,), dtype=torch.int32, device=self.device
            )
        if self.cache.num_sliding_attention_groups:
            self.cumulative_seqlens_k["sliding_attention"] = torch.empty(
                (self.cache.max_batch_tokens + 1,), dtype=torch.int32, device=self.device
            )
        self.max_seqlen_k = dict.fromkeys(self.cumulative_seqlens_k.keys(), 0)

        # Dense masks are only materialized for eager/SDPA paged backends; other backends get None
        if attn_mask_is_needed(self.config):
            self.attention_mask = {}
            for layer_type in self.cumulative_seqlens_k.keys():
                self.attention_mask[layer_type] = torch.empty(
                    size=(1, 1, self.cache.max_batch_tokens, num_pages + self.cache.max_batch_tokens),
                    dtype=self.model_dtype,
                    device=self.device,
                )
        else:
            self.attention_mask = None

        # For other kwargs, we need a list of tensors with as many tensors as there are groups
        self.write_index_storage = [
            torch.empty((self.cache.max_batch_tokens,), dtype=torch.int32, device=self.device)
            for _ in range(self.cache.num_groups)
        ]
        self.read_index_storage = [
            torch.empty((num_pages + self.cache.max_batch_tokens), dtype=torch.int32, device=self.device)
            for _ in range(self.cache.num_groups)
        ]
        # For the read index, the extra max_batch_tokens slots account for the in-flight query tokens
        # on top of the cache pages (relevant when the model uses a sliding window)

    @traced
    @torch.no_grad()
    def reset_static_tensors(self, full_reset: bool = False) -> None:
        """Reset static tensors for the next batch. For efficiency, this only resets the portions of tensors that were
        actually used in the previous batch, using the attributes actual_query_length, actual_key_length, and
        actual_batch_size. If a (full_reset) is requested, the entire tensor storage is reset.
        """
        # Compute the slice to reset: full storage sizes on a full reset, last batch's extents otherwise
        q_len = self.write_index_storage[0].size(-1) if full_reset else self.actual_query_length
        k_len = self.read_index_storage[0].size(-1) if full_reset else self.actual_key_length
        b_size = self.write_index_storage[0].size(0) if full_reset else self.actual_batch_size

        # Reset the attributes that always have the same shape
        self.input_ids[:, :q_len].zero_()
        self.position_ids[:, :q_len].zero_()
        self.cumulative_seqlens_q[: b_size + 1].zero_()
        self.max_seqlen_q = 0
        self.logits_indices[:q_len].fill_(-1)
        self.output_ids[:q_len].fill_(-1)

        # Reset the attributes that are either tensors or dict of tensors
        for layer_type in self.cumulative_seqlens_k:
            self.cumulative_seqlens_k[layer_type][: b_size + 1].zero_()
            self.max_seqlen_k[layer_type] = 0
            if self.attention_mask is not None:
                # Masks default to fully-masked (dtype min); prepare_batch_tensors opens the allowed positions
                self.attention_mask[layer_type][:, :, :q_len, :k_len].fill_(torch.finfo(self.model_dtype).min)

        # Reset the attributes that are lists of tensors
        for i in range(self.cache.num_groups):
            # Filled with -2: -1 is reserved to tell the cache where new states go
            self.write_index_storage[i][:q_len].fill_(-2)
            self.read_index_storage[i][: q_len + k_len].fill_(-2)  # same sentinel as above

    @traced
    def prepare_batch_tensors(self, requests_in_batch: list[RequestState]) -> None:
        """Prepare tensors and metadata for the next model forward pass, using the given requests as data. This method:

        1. Resets the static tensors from the previous batch
        2. Iterates through requests to accumulate input_ids, position_ids, and sequence lengths
        3. Extends read/write indices for cache management
        4. Builds attention masks if needed (for eager/SDPA implementations)
        5. Converts accumulated lists to tensors and copies them to static storage

        This method also modifies the `position_offset` attribute of each request to track progress and adds a
        temporary token at the end of the requests for which there will be a new token.

        Raises:
            ValueError: If `requests_in_batch` is empty.
        """
        # Keep track of these requests in the batch, which will be useful to update the batch later
        self.requests_in_batch = requests_in_batch
        if not self.requests_in_batch:
            raise ValueError("No requests in batch")

        # Reset the static tensors used for storage
        self.reset_static_tensors()  # FIXME: why does this make the generation faster?
        # Reset accumulators
        self.actual_query_length = 0
        self.actual_key_length = 0
        self.actual_batch_size = 0

        # Prepare accumulators (python lists are cheaper to grow than in-place tensor writes)
        input_ids = []
        position_ids = []
        cumulative_seqlens_q = [0]
        logits_indices = []
        cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k.keys()}
        read_index = [[] for _ in range(self.cache.num_groups)]
        write_index = [[] for _ in range(self.cache.num_groups)]

        # Go through all the requests in the batch
        for state in self.requests_in_batch:
            # First we retrieve the lengths related to the request
            past_length = state.position_offset
            query_length = len(state.tokens_to_process)
            # Per-layer-type key lengths (full vs sliding attention may differ)
            seqlens_k = self.cache.get_seqlens_k(past_length, query_length)

            # Then we update the total lengths that are used for slicing
            self.actual_query_length += query_length
            # actual_key_length is used to slice the keys so we need to take the max of all the key lengths
            self.actual_key_length += max(seqlens_k.values())
            self.actual_batch_size += 1
            # And the attribute tracking the position in the request object
            state.position_offset += query_length

            # Then we accumulate for the object used in the kwargs
            input_ids.extend(state.tokens_to_process)
            position_ids.extend(range(past_length, past_length + query_length))
            cumulative_seqlens_q.append(cumulative_seqlens_q[-1] + query_length)
            self.max_seqlen_q = max(self.max_seqlen_q, query_length)

            # Accumulate the key sequence lengths for the current request
            for layer_type, layer_type_seqlen_k in seqlens_k.items():
                cumulative_seqlens_k[layer_type].append(cumulative_seqlens_k[layer_type][-1] + layer_type_seqlen_k)
                self.max_seqlen_k[layer_type] = max(self.max_seqlen_k[layer_type], layer_type_seqlen_k)

            # We extend the read and write indices for the cache
            self.cache.extend_read_and_write_indices(
                state.request_id, past_length, query_length, read_index, write_index
            )

            # If the request has no remaining prefill tokens, it means the next token prediction is relevant
            if not state.remaining_prefill_tokens:
                logits_indices.append(cumulative_seqlens_q[-1] - 1)
                # Placeholder slot; the real token is written in once decoding produces it
                state.generated_tokens.append(TMP_TOKEN_ID)

        # When looping over requests is done, we can build the actual tensors. This is faster than modifying the
        # static tensors inside the loop.
        to_tensor = partial(torch.tensor, dtype=torch.int32, device=self.device)

        # Those kwargs always have the same type regardless of the model
        self.input_ids[:, : len(input_ids)] = to_tensor(input_ids)
        self.position_ids[:, : len(position_ids)] = to_tensor(position_ids)
        self.cumulative_seqlens_q[: len(cumulative_seqlens_q)] = to_tensor(cumulative_seqlens_q)
        self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices)
        self.total_seqlen_q = cumulative_seqlens_q[-1]

        # Those kwargs are either dict of tensors or tensors, so we need to handle both cases
        for layer_type, layer_type_seqlens_k in cumulative_seqlens_k.items():
            self.cumulative_seqlens_k[layer_type][: len(layer_type_seqlens_k)] = to_tensor(layer_type_seqlens_k)
            if self.attention_mask is not None:
                build_attention_mask(
                    attention_mask=self.attention_mask[layer_type],
                    cumulative_seqlens_q=cumulative_seqlens_q,
                    cumulative_seqlens_k=layer_type_seqlens_k,
                    sliding_window=self.sliding_window if layer_type == "sliding_attention" else 1,
                )

        # The index lists only contain references to the storage tensors, so we update the storage and their sizes
        self.read_index = []
        self.write_index = []
        for i, group_read_indices, group_write_indices in zip(count(), read_index, write_index):
            self.read_index_storage[i][: len(group_read_indices)] = to_tensor(group_read_indices)
            self.write_index_storage[i][: len(group_write_indices)] = to_tensor(group_write_indices)
            self.actual_index_sizes[i] = (len(group_read_indices), len(group_write_indices))

    def get_model_kwargs(self, padded_q_size: int = 0, padded_kv_cache_size: int = 0) -> dict[str, Any]:
        """Get model keyword arguments for the current batch, eventually padding the query dimension to (padded_q_size)
        and the keys/values dimension to (padded_kv_cache_size). The padding is only useful if we want static shapes,
        like when using cuda graphs AND only activated if both Q and KV are padded."""
        # Compute the slice to return, with the given padding if we are using cuda graphs
        use_padding = padded_q_size > 0 and padded_kv_cache_size > 0
        q_len = padded_q_size if use_padding else self.actual_query_length
        # NOTE(review): when padding, the batch-size slice is padded with the Q size (not a batch size) so that
        # cumulative_seqlens_q keeps a static shape -- see the padding fixup below
        b_size = padded_q_size if use_padding else self.actual_batch_size
        # If there is padding, the size of the KV is the nb of padded Q tokens + the size of the padded KV cache
        padded_kv_size = padded_q_size + padded_kv_cache_size

        # Prepare the kwargs, the attributes that are either tensors or dict of tensors are initialized to empty dicts
        kwargs = PagedAttentionArgs(
            input_ids=self.input_ids[:, :q_len],
            position_ids=self.position_ids[:, :q_len],
            cu_seq_lens_q=self.cumulative_seqlens_q[: b_size + 1],
            max_seqlen_q=self.max_seqlen_q,
            logits_indices=self.logits_indices[:q_len],
            cu_seq_lens_k={},
            max_seqlen_k={},
            attention_mask={},
            read_index=[],
            write_index=[],
            cache=self.cache,
            use_cache=False,
        )

        # If we use constant-sized slicing, there are some "padding" queries tokens which FA has some issues with. In
        # some models like Qwen3-4B-Instruct-2507, if we don't include these tokens in cumulative_seqlens_q, there are
        # some NaNs in the output logits even for non-padded tokens.
        if use_padding:
            self.max_seqlen_q = max(self.max_seqlen_q, q_len - self.total_seqlen_q)
            kwargs.max_seqlen_q = self.max_seqlen_q
            self.cumulative_seqlens_q[self.actual_batch_size + 1 :] = q_len
            # FIXME: is there another way to avoid this? It has a very slight impact on performance (~5 tok/s)

        # For the attributes that are lists of tensors, we construct list of tensor references
        for i, (read_index_size, write_index_size) in enumerate(self.actual_index_sizes):
            read_index_size = padded_kv_size if use_padding else read_index_size
            write_index_size = padded_q_size if use_padding else write_index_size
            kwargs.read_index.append(self.read_index_storage[i][:read_index_size])
            kwargs.write_index.append(self.write_index_storage[i][:write_index_size])

        # For the attributes that are dict of tensors, we replace the dict with a tensor if there is only one entry
        layer_types = list(self.cumulative_seqlens_k.keys())
        if len(layer_types) > 1:
            kwargs.max_seqlen_k: dict[str, int] = {}
            kwargs.cu_seq_lens_k: dict[str, torch.Tensor] = {}
            kwargs.attention_mask: dict[str, torch.Tensor] = {}
            for layer_type, seqlens_k in self.cumulative_seqlens_k.items():
                kwargs.cu_seq_lens_k[layer_type] = seqlens_k[: b_size + 1]
                kwargs.max_seqlen_k[layer_type] = self.max_seqlen_k[layer_type]
                if self.attention_mask is not None:
                    k_len = padded_kv_size if use_padding else seqlens_k[b_size]
                    kwargs.attention_mask[layer_type] = self.attention_mask[layer_type][..., :q_len, :k_len]
        else:
            layer_type = layer_types[0]
            kwargs.cu_seq_lens_k = self.cumulative_seqlens_k[layer_type][: b_size + 1]
            kwargs.max_seqlen_k = self.max_seqlen_k[layer_type]
            if self.attention_mask is not None:
                k_len = padded_kv_size if use_padding else self.cumulative_seqlens_k[layer_type][b_size]
                kwargs.attention_mask = self.attention_mask[layer_type][..., :q_len, :k_len]

        # If no mask was materialized at setup time, pass None instead of the empty placeholder dict
        if self.attention_mask is None:
            kwargs.attention_mask = None
        return kwargs.asdict()  # TODO: this is imperfect, check if there is no better way to juggle dict / dataclass
|
|
@@ -17,11 +17,18 @@ from enum import Enum
|
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
19
|
|
|
20
|
-
from ...utils import is_torch_xpu_available
|
|
20
|
+
from ...utils import is_psutil_available, is_torch_xpu_available
|
|
21
21
|
from ...utils.logging import logging
|
|
22
22
|
from ...utils.metrics import traced
|
|
23
23
|
|
|
24
24
|
|
|
25
|
+
if is_psutil_available():
|
|
26
|
+
import psutil
|
|
27
|
+
|
|
28
|
+
# This is a temporary token ID used to represent a token that is not yet generated
|
|
29
|
+
TMP_TOKEN_ID = -1
|
|
30
|
+
|
|
31
|
+
|
|
25
32
|
# We centralize the logger here to coordinate between logging and progress bar
|
|
26
33
|
logger = logging.getLogger("ContinuousBatchingLogger")
|
|
27
34
|
|
|
@@ -49,9 +56,19 @@ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
|
|
|
49
56
|
reserved_memory = 0 # MPS does not track reserved separately
|
|
50
57
|
else:
|
|
51
58
|
device = torch.device("cpu")
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
59
|
+
if is_psutil_available():
|
|
60
|
+
total_memory = psutil.virtual_memory().total
|
|
61
|
+
allocated_memory = psutil.Process().memory_info().rss
|
|
62
|
+
reserved_memory = allocated_memory
|
|
63
|
+
else:
|
|
64
|
+
logger.error(
|
|
65
|
+
"Cannot get memory breakdown on CPU without psutil: returning 0 for all memory values. Please install "
|
|
66
|
+
"psutil to get an actual memory breakdown."
|
|
67
|
+
)
|
|
68
|
+
total_memory = 0
|
|
69
|
+
reserved_memory = 0
|
|
70
|
+
allocated_memory = 0
|
|
71
|
+
|
|
55
72
|
return device, total_memory, reserved_memory, allocated_memory
|
|
56
73
|
|
|
57
74
|
|
|
@@ -79,6 +96,7 @@ class GenerationOutput:
|
|
|
79
96
|
error (Optional[str]): Any error message associated with the request. When None, the request was successful.
|
|
80
97
|
status (RequestStatus): The status of the request.
|
|
81
98
|
created_time (float): The time the request was created.
|
|
99
|
+
lifespan (tuple[float, float]): The time the request was no longer pending and the time the request finished.
|
|
82
100
|
"""
|
|
83
101
|
|
|
84
102
|
request_id: str
|
|
@@ -88,6 +106,7 @@ class GenerationOutput:
|
|
|
88
106
|
error: str | None = None
|
|
89
107
|
status: RequestStatus = RequestStatus.PENDING
|
|
90
108
|
created_time: float = field(default_factory=time.perf_counter)
|
|
109
|
+
lifespan: tuple[float, float] = (-1, -1) # (time request was no longer pending, time request finished)
|
|
91
110
|
timestamps: list[float] | None = None # Timestamps of the generated tokens
|
|
92
111
|
|
|
93
112
|
def is_finished(self) -> bool:
|
|
@@ -110,7 +129,7 @@ class RequestState:
|
|
|
110
129
|
position_offset (int): The current position in the sequence for position_ids.
|
|
111
130
|
status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT,
|
|
112
131
|
SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED
|
|
113
|
-
max_new_tokens (int): The maximum number of new tokens to generate.
|
|
132
|
+
max_new_tokens (int | None): The maximum number of new tokens to generate.
|
|
114
133
|
eos_token_id (int): The ID of the end-of-sequence token.
|
|
115
134
|
streaming (bool): Whether to stream tokens as they're generated
|
|
116
135
|
created_time (float): The time the request was created.
|
|
@@ -124,13 +143,13 @@ class RequestState:
|
|
|
124
143
|
record_timestamps: bool = False # Whether to record timestamps for the generated tokens
|
|
125
144
|
num_children: int = 0 # Number of children requests
|
|
126
145
|
# Internal fields
|
|
127
|
-
tokens_to_process: list[int]
|
|
146
|
+
tokens_to_process: list[int] = field(default_factory=list) # Tokens IDs currently being processed
|
|
128
147
|
remaining_prefill_tokens: list[int] = field(default_factory=list) # For split requests, prefill left to process
|
|
129
148
|
generated_tokens: list[int] = field(default_factory=list) # Generated tokens
|
|
130
149
|
allocated_blocks: int = 0 # Number of blocks allocated to the request
|
|
131
150
|
position_offset: int = 0 # Current position in the sequence for position_ids
|
|
132
151
|
_status: RequestStatus = RequestStatus.PENDING # Status of the request, hidden behind a property
|
|
133
|
-
max_new_tokens: int = 20 # Maximum number of new tokens to generate
|
|
152
|
+
max_new_tokens: int | None = 20 # Maximum number of new tokens to generate. None means no limit. Default to 20.
|
|
134
153
|
eos_token_id: int = -1 # ID of the end-of-sequence token
|
|
135
154
|
streaming: bool = False # Whether to stream tokens as they're generated
|
|
136
155
|
created_time: float = field(default_factory=time.perf_counter) # Time the request was created
|
|
@@ -139,6 +158,11 @@ class RequestState:
|
|
|
139
158
|
_timestamps: list[float] = field(default_factory=list) # Timestamps of the generated tokens
|
|
140
159
|
_true_initial_tokens: int = 0 # The true number of initial tokens, useful when soft resetting requests
|
|
141
160
|
# TODO: remove the attribute above to _num_initial_tokens once initial_tokens is renamed
|
|
161
|
+
_new_tokens_limit: int = 2147483647 # An int to check the max number of new tokens w/out always comparing w/ None
|
|
162
|
+
|
|
163
|
+
def __post_init__(self):
    """Derive the integer new-token limit from `max_new_tokens`.

    `None` means "no limit": an absurdly high sentinel (INT32_MAX) is stored instead, so
    length checks can compare against an int without special-casing None every time.
    """
    if self.max_new_tokens is None:
        self._new_tokens_limit = 2147483647
    else:
        self._new_tokens_limit = self.max_new_tokens
|
|
142
166
|
|
|
143
167
|
@property
|
|
144
168
|
def status(self) -> RequestStatus:
|
|
@@ -193,18 +217,23 @@ class RequestState:
|
|
|
193
217
|
if self.record_timestamps:
|
|
194
218
|
self._timestamps.append(time.perf_counter())
|
|
195
219
|
|
|
220
|
+
# Stop if we reached an EOS token
|
|
196
221
|
is_eos = token_id == self.eos_token_id and self.eos_token_id != -1
|
|
197
|
-
|
|
222
|
+
current_len = self.generated_len() - 1 # do not count the temporary token
|
|
198
223
|
|
|
199
|
-
#
|
|
224
|
+
# Replace the temporary token if we're not finishing due to max length
|
|
200
225
|
# (EOS tokens should still be added to the output)
|
|
201
|
-
if
|
|
202
|
-
self.generated_tokens
|
|
203
|
-
|
|
204
|
-
|
|
226
|
+
if is_eos or (current_len < self._new_tokens_limit):
|
|
227
|
+
self.generated_tokens[-1] = token_id
|
|
228
|
+
current_len += 1
|
|
229
|
+
else:
|
|
230
|
+
logger.warning(f"Request {self.request_id} generated a useless token: {token_id}")
|
|
231
|
+
self.generated_tokens.pop()
|
|
232
|
+
|
|
233
|
+
if is_eos or current_len >= self._new_tokens_limit:
|
|
205
234
|
self.status = RequestStatus.FINISHED
|
|
206
235
|
return True
|
|
207
|
-
return False
|
|
236
|
+
return False # We still need to process more tokens
|
|
208
237
|
|
|
209
238
|
def __repr__(self):
|
|
210
239
|
msg = [
|
|
@@ -222,16 +251,20 @@ class RequestState:
|
|
|
222
251
|
|
|
223
252
|
def to_generation_output(self):
|
|
224
253
|
"""Convert the request state to a GenerationOutput object."""
|
|
254
|
+
if self.generated_tokens and self.generated_tokens[-1] == TMP_TOKEN_ID:
|
|
255
|
+
self.generated_tokens.pop()
|
|
225
256
|
if self._true_initial_tokens:
|
|
226
257
|
self.generated_tokens = self.initial_tokens[self._true_initial_tokens :] + self.generated_tokens
|
|
227
258
|
self.initial_tokens = self.initial_tokens[: self._true_initial_tokens]
|
|
228
259
|
return GenerationOutput(
|
|
229
260
|
request_id=self.request_id,
|
|
230
261
|
prompt_ids=self.initial_tokens,
|
|
231
|
-
status=self.status,
|
|
232
262
|
generated_tokens=self.generated_tokens,
|
|
233
263
|
logprobs=[],
|
|
234
264
|
error=self.error,
|
|
265
|
+
status=self.status,
|
|
266
|
+
created_time=self.created_time,
|
|
267
|
+
lifespan=self.lifespan,
|
|
235
268
|
timestamps=self.timestamps,
|
|
236
269
|
)
|
|
237
270
|
|
|
@@ -253,7 +286,7 @@ class RequestState:
|
|
|
253
286
|
streaming=self.streaming,
|
|
254
287
|
created_time=t,
|
|
255
288
|
lifespan=(t, -1),
|
|
256
|
-
_timestamps=
|
|
289
|
+
_timestamps=[],
|
|
257
290
|
error=self.error,
|
|
258
291
|
record_timestamps=self.record_timestamps,
|
|
259
292
|
)
|
|
@@ -263,13 +296,17 @@ class RequestState:
|
|
|
263
296
|
"""Creates an equivalent new request by removing the generated tokens and adding them to the initial prompt. The
|
|
264
297
|
created request has THE SAME request_id. Notably, we can retrieve the original request from the created one with
|
|
265
298
|
the _true_initial_tokens attribute."""
|
|
299
|
+
# Remove the temporary token if it exists
|
|
300
|
+
if self.generated_tokens and self.generated_tokens[-1] == TMP_TOKEN_ID:
|
|
301
|
+
self.generated_tokens.pop()
|
|
302
|
+
max_new_tokens = None if self.max_new_tokens is None else (self.max_new_tokens - len(self.generated_tokens))
|
|
266
303
|
new_state = RequestState(
|
|
267
304
|
request_id=self.request_id,
|
|
268
305
|
initial_tokens=self.initial_tokens + self.generated_tokens,
|
|
269
306
|
num_children=self.num_children,
|
|
270
307
|
record_timestamps=self.record_timestamps,
|
|
271
308
|
tokens_to_process=self.initial_tokens + self.generated_tokens,
|
|
272
|
-
max_new_tokens=
|
|
309
|
+
max_new_tokens=max_new_tokens,
|
|
273
310
|
eos_token_id=self.eos_token_id,
|
|
274
311
|
streaming=self.streaming,
|
|
275
312
|
)
|