transformers 5.0.0rc3__py3-none-any.whl → 5.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
|
@@ -14,11 +14,9 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
import queue
|
|
16
16
|
import threading
|
|
17
|
+
from abc import abstractmethod
|
|
17
18
|
from collections.abc import Generator
|
|
18
19
|
from contextlib import contextmanager
|
|
19
|
-
from dataclasses import dataclass
|
|
20
|
-
from functools import partial
|
|
21
|
-
from itertools import count
|
|
22
20
|
from math import ceil
|
|
23
21
|
from time import perf_counter
|
|
24
22
|
|
|
@@ -29,10 +27,11 @@ from tqdm.contrib.logging import logging_redirect_tqdm
|
|
|
29
27
|
|
|
30
28
|
from ...configuration_utils import PretrainedConfig
|
|
31
29
|
from ...generation.configuration_utils import CompileConfig, GenerationConfig
|
|
32
|
-
from ...generation.logits_process import
|
|
30
|
+
from ...generation.logits_process import LogitsProcessorList
|
|
33
31
|
from ...utils.logging import logging
|
|
34
32
|
from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
|
|
35
33
|
from .cache import PagedAttentionCache
|
|
34
|
+
from .input_ouputs import ContinuousBatchingIOs, attn_mask_is_needed
|
|
36
35
|
from .requests import GenerationOutput, RequestState, RequestStatus, logger
|
|
37
36
|
from .scheduler import SCHEDULER_MAPPING, FIFOScheduler, Scheduler
|
|
38
37
|
|
|
@@ -69,109 +68,19 @@ def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
|
|
|
69
68
|
return min(padded, max_value)
|
|
70
69
|
|
|
71
70
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
71
|
+
# We cannot use `PreTrainedModel` for circular import reasons, so this helps keep track of the basic types
|
|
72
|
+
class ProtoPretrainedModel(nn.Module):
|
|
73
|
+
config: PretrainedConfig
|
|
74
|
+
dtype: torch.dtype
|
|
75
|
+
device: torch.device
|
|
75
76
|
|
|
77
|
+
@abstractmethod
|
|
78
|
+
def set_attn_implementation(self, attn_implementation: str) -> None:
|
|
79
|
+
pass
|
|
76
80
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
cumulative_seqlens_k: list[int],
|
|
81
|
-
sliding_window: int = 1,
|
|
82
|
-
) -> None:
|
|
83
|
-
"""Builds an attention mask inplace using the cumulative seqlens of the query and key. If given a sliding window, it
|
|
84
|
-
will also apply a sliding window mask on top. The attention mask is not boolean, it uses zeroes and -inf (or its
|
|
85
|
-
equivalent) so it's more of an attention score bias tensor.
|
|
86
|
-
The attention mask is a block-diagonal matrix, with each block an attention mask for a single query-key pair.
|
|
87
|
-
Each of those block is built from a causal mask and, if there is a sliding window, a sliding window mask.
|
|
88
|
-
|
|
89
|
-
An example is represented below, with seqlen_k = 8, seqlen_q = 4 and sliding_window = 6:
|
|
90
|
-
|
|
91
|
-
CAUSAL MASK:
|
|
92
|
-
|
|
93
|
-
█ █ █ █ █ ░ ░ ░
|
|
94
|
-
█ █ █ █ █ █ ░ ░
|
|
95
|
-
█ █ █ █ █ █ █ ░
|
|
96
|
-
█ █ █ █ █ █ █ █
|
|
97
|
-
|
|
98
|
-
SLIDING WINDOW MASK:
|
|
99
|
-
┌──────────────────────── seqlen_k - seqlen_q - sliding_window = 8 - 4 - 6 = -2 offset to the left
|
|
100
|
-
<─┴─>
|
|
101
|
-
░ █ | █ █ █ █ █ █ █ █
|
|
102
|
-
░ ░ | █ █ █ █ █ █ █ █
|
|
103
|
-
░ ░ | ░ █ █ █ █ █ █ █
|
|
104
|
-
░ ░ | ░ ░ █ █ █ █ █ █
|
|
105
|
-
|
|
106
|
-
ATTENTION MASK (sum of causal and sliding window masks):
|
|
107
|
-
|
|
108
|
-
█ █ █ █ █ ░ ░ ░
|
|
109
|
-
█ █ █ █ █ █ ░ ░
|
|
110
|
-
░ █ █ █ █ █ █ ░
|
|
111
|
-
░ ░ █ █ █ █ █ █
|
|
112
|
-
|
|
113
|
-
Another example with seqlen_k = 5, seqlen_q = 3 and sliding_window = 2:
|
|
114
|
-
|
|
115
|
-
CAUSAL MASK:
|
|
116
|
-
|
|
117
|
-
█ █ █ ░ ░
|
|
118
|
-
█ █ █ █ ░
|
|
119
|
-
█ █ █ █ █
|
|
120
|
-
|
|
121
|
-
SLIDING WINDOW MASK:
|
|
122
|
-
┌──────────────────────── seqlen_k - seqlen_q - sliding_window = 5 - 3 - 2 = 0 offset to the left
|
|
123
|
-
<┴>
|
|
124
|
-
| ░ █ █ █ █
|
|
125
|
-
| ░ ░ █ █ █
|
|
126
|
-
| ░ ░ ░ █ █
|
|
127
|
-
|
|
128
|
-
ATTENTION MASK (sum of causal and sliding window masks):
|
|
129
|
-
|
|
130
|
-
░ █ █ ░ ░
|
|
131
|
-
░ ░ █ █ ░
|
|
132
|
-
░ ░ ░ █ █
|
|
133
|
-
|
|
134
|
-
"""
|
|
135
|
-
min_value = torch.finfo(attention_mask.dtype).min
|
|
136
|
-
for i in range(len(cumulative_seqlens_q) - 1):
|
|
137
|
-
seqlen_q = cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]
|
|
138
|
-
seqlen_k = cumulative_seqlens_k[i + 1] - cumulative_seqlens_k[i]
|
|
139
|
-
if seqlen_q < seqlen_k and seqlen_q >= 1:
|
|
140
|
-
causal_diagonal = seqlen_k - seqlen_q + 1
|
|
141
|
-
else:
|
|
142
|
-
causal_diagonal = 1
|
|
143
|
-
query_range = slice(cumulative_seqlens_q[i], cumulative_seqlens_q[i + 1])
|
|
144
|
-
key_range = slice(cumulative_seqlens_k[i], cumulative_seqlens_k[i + 1])
|
|
145
|
-
# Apply causal mask
|
|
146
|
-
minus_inf = torch.full(
|
|
147
|
-
attention_mask[..., query_range, key_range].shape,
|
|
148
|
-
min_value,
|
|
149
|
-
dtype=attention_mask.dtype,
|
|
150
|
-
device=attention_mask.device,
|
|
151
|
-
)
|
|
152
|
-
masked = torch.triu(minus_inf, diagonal=causal_diagonal)
|
|
153
|
-
# Apply sliding window mask if needed
|
|
154
|
-
if sliding_window > 1:
|
|
155
|
-
sliding_diagonal = seqlen_k - seqlen_q - sliding_window
|
|
156
|
-
masked += torch.tril(minus_inf, diagonal=sliding_diagonal)
|
|
157
|
-
# Replace in attention mask
|
|
158
|
-
attention_mask[..., query_range, key_range] = masked
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
@dataclass
|
|
162
|
-
class PagedAttentionArgs:
|
|
163
|
-
input_ids: torch.Tensor
|
|
164
|
-
attention_mask: torch.Tensor | None
|
|
165
|
-
position_ids: torch.Tensor
|
|
166
|
-
cumulative_seqlens_q: torch.Tensor
|
|
167
|
-
cumulative_seqlens_k: torch.Tensor
|
|
168
|
-
max_seqlen_q: int
|
|
169
|
-
max_seqlen_k: int
|
|
170
|
-
write_index: list[torch.Tensor]
|
|
171
|
-
read_index: list[torch.Tensor]
|
|
172
|
-
logits_indices: torch.Tensor
|
|
173
|
-
cache: PagedAttentionCache
|
|
174
|
-
use_cache: bool = False
|
|
81
|
+
@abstractmethod
|
|
82
|
+
def _get_logits_processor(self, generation_config: GenerationConfig) -> LogitsProcessorList:
|
|
83
|
+
pass
|
|
175
84
|
|
|
176
85
|
|
|
177
86
|
# Continuous Batch Processor (Internal Logic)
|
|
@@ -238,160 +147,14 @@ class ContinuousBatchProcessor:
|
|
|
238
147
|
self.max_batch_tokens = cache.max_batch_tokens
|
|
239
148
|
self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens)
|
|
240
149
|
|
|
241
|
-
# Setup
|
|
242
|
-
self.
|
|
243
|
-
self.actual_key_length = 0 # This is the actual number of keys/values tokens in the batch
|
|
244
|
-
self.actual_batch_size = 0 # This is the actual number of requests in the batch
|
|
245
|
-
self.actual_index_sizes = [(0, 0) for _ in range(cache.num_groups)]
|
|
246
|
-
self.setup_static_tensors(cache.num_groups)
|
|
247
|
-
|
|
248
|
-
@traced(standalone=True)
|
|
249
|
-
def setup_static_tensors(self, num_groups: int) -> None:
|
|
250
|
-
"""Setup the static tensors that are used for storage during the generation step. No other tensor will be
|
|
251
|
-
allowed for the inputs or the outputs of the generation step."""
|
|
252
|
-
self.num_pages = self.cache.num_blocks * self.cache.block_size
|
|
253
|
-
self.tensor_metadata = {"dtype": torch.int32, "device": self.model_device}
|
|
254
|
-
|
|
255
|
-
# Some tensors always have the same shape regardless of the model
|
|
256
|
-
self.input_ids = torch.empty((1, self.max_batch_tokens), **self.tensor_metadata)
|
|
257
|
-
self.position_ids = torch.empty((1, self.max_batch_tokens), **self.tensor_metadata)
|
|
258
|
-
self.cumulative_seqlens_q = torch.empty((self.max_batch_tokens + 1,), **self.tensor_metadata)
|
|
259
|
-
self.max_seqlen_q = 0
|
|
260
|
-
self.logits_indices = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)
|
|
261
|
-
self.output_ids = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)
|
|
262
|
-
|
|
263
|
-
# For some kwargs, we have a dict of tensors with as many items as there are attention types
|
|
264
|
-
layer_types = getattr(self.config, "layer_types", None)
|
|
265
|
-
if layer_types is None:
|
|
266
|
-
sliding_window = getattr(self.config, "sliding_window", 1)
|
|
267
|
-
layer_types = ["full_attention"] if sliding_window in [1, None] else ["sliding_attention"]
|
|
268
|
-
layer_types = list(set(layer_types))
|
|
269
|
-
|
|
270
|
-
self.cumulative_seqlens_k = {
|
|
271
|
-
l_type: torch.empty((self.max_batch_tokens + 1), **self.tensor_metadata) for l_type in layer_types
|
|
272
|
-
}
|
|
273
|
-
self.max_seqlen_k = dict.fromkeys(layer_types, 0)
|
|
274
|
-
|
|
275
|
-
if attn_mask_is_needed(self.config):
|
|
276
|
-
attn_mask_kwargs = {
|
|
277
|
-
"size": (1, 1, self.max_batch_tokens, self.num_pages + self.max_batch_tokens),
|
|
278
|
-
"dtype": self.model_dtype,
|
|
279
|
-
"device": self.model_device,
|
|
280
|
-
}
|
|
281
|
-
self.attention_mask = {layer_type: torch.empty(**attn_mask_kwargs) for layer_type in layer_types}
|
|
282
|
-
else:
|
|
283
|
-
self.attention_mask = None
|
|
284
|
-
|
|
285
|
-
# For other kwargs, we need a list of tensors with as many tensors as there are groups
|
|
286
|
-
self.write_index_storage = [
|
|
287
|
-
torch.empty((self.max_batch_tokens,), **self.tensor_metadata) for _ in range(num_groups)
|
|
288
|
-
]
|
|
289
|
-
self.read_index_storage = [
|
|
290
|
-
torch.empty((self.num_pages + self.max_batch_tokens), **self.tensor_metadata) for _ in range(num_groups)
|
|
291
|
-
]
|
|
292
|
-
# For read index, the +T is because there are -1 for seqlen_q when model uses a sliding window
|
|
293
|
-
|
|
294
|
-
# After allocating empty tensors, we reset them to the right value
|
|
295
|
-
self.reset_static_tensors(full_reset=True)
|
|
296
|
-
|
|
297
|
-
@traced
|
|
298
|
-
@torch.no_grad()
|
|
299
|
-
def reset_static_tensors(self, full_reset: bool = False) -> None:
|
|
300
|
-
"""Reset static tensors for the next batch. In between batches, reset only the parts that were used in the last
|
|
301
|
-
batch, but for initialisation, we can reset everything using the (full_reset) flag."""
|
|
302
|
-
# Compute the slice to reset
|
|
303
|
-
q_len = self.write_index_storage[0].size(-1) if full_reset else self.actual_query_length
|
|
304
|
-
k_len = self.read_index_storage[0].size(-1) if full_reset else self.actual_key_length
|
|
305
|
-
b_size = self.write_index_storage[0].size(0) if full_reset else self.actual_batch_size
|
|
306
|
-
|
|
307
|
-
# Reset the attributes that always have the same shape
|
|
308
|
-
self.input_ids[:, :q_len].zero_()
|
|
309
|
-
self.position_ids[:, :q_len].zero_()
|
|
310
|
-
self.cumulative_seqlens_q[: b_size + 1].zero_()
|
|
311
|
-
self.max_seqlen_q = 0
|
|
312
|
-
self.logits_indices[:q_len].fill_(-1)
|
|
313
|
-
self.output_ids[:q_len].fill_(-1)
|
|
314
|
-
|
|
315
|
-
# Reset the attributes that are either tensors or dict of tensors
|
|
316
|
-
for layer_type in self.cumulative_seqlens_k:
|
|
317
|
-
self.cumulative_seqlens_k[layer_type][: b_size + 1].zero_()
|
|
318
|
-
self.max_seqlen_k[layer_type] = 0
|
|
319
|
-
if self.attention_mask is not None:
|
|
320
|
-
self.attention_mask[layer_type][:, :, :q_len, :k_len].fill_(torch.finfo(self.model_dtype).min)
|
|
321
|
-
|
|
322
|
-
# Reset the attributes that are lists of tensors
|
|
323
|
-
for i in range(self.cache.num_groups):
|
|
324
|
-
self.write_index_storage[i][:q_len].fill_(-2) # -1 is used to let the cache where new states go
|
|
325
|
-
self.read_index_storage[i][: q_len + k_len].fill_(-2) # same
|
|
326
|
-
|
|
327
|
-
def get_model_kwargs(self, padded_q_size: int = 0, padded_kv_cache_size: int = 0) -> PagedAttentionArgs:
|
|
328
|
-
"""Get model keyword arguments for the current batch, eventually padding the query dimension to (padded_q_size)
|
|
329
|
-
and the keys/values dimension to (padded_kv_cache_size). The padding is only useful if we want static shapes,
|
|
330
|
-
like when using cuda graphs AND only activated if both Q and KV are padded."""
|
|
331
|
-
# Compute the slice to return, with the given padding if we are using cuda graphs
|
|
332
|
-
use_padding = padded_q_size > 0 and padded_kv_cache_size > 0
|
|
333
|
-
q_len = padded_q_size if use_padding else self.actual_query_length
|
|
334
|
-
b_size = padded_q_size if use_padding else self.actual_batch_size
|
|
335
|
-
# If there is padding, the size of the KV is the nb of padded Q tokens + the size padded of the padded KV cache
|
|
336
|
-
padded_kv_size = padded_q_size + padded_kv_cache_size
|
|
337
|
-
|
|
338
|
-
# Prepare the kwargs, the attributes that are either tensors or dict of tensors are initialized to empty dicts
|
|
339
|
-
kwargs = {
|
|
340
|
-
"input_ids": self.input_ids[:, :q_len],
|
|
341
|
-
"position_ids": self.position_ids[:, :q_len],
|
|
342
|
-
"cu_seq_lens_q": self.cumulative_seqlens_q[: b_size + 1],
|
|
343
|
-
"max_seqlen_q": self.max_seqlen_q,
|
|
344
|
-
"logits_indices": self.logits_indices[:q_len],
|
|
345
|
-
"cu_seq_lens_k": {},
|
|
346
|
-
"max_seqlen_k": {},
|
|
347
|
-
"attention_mask": {},
|
|
348
|
-
"read_index": [],
|
|
349
|
-
"write_index": [],
|
|
350
|
-
"cache": self.cache,
|
|
351
|
-
"use_cache": False,
|
|
352
|
-
}
|
|
353
|
-
|
|
354
|
-
# If we use constant-sized slicing, there are some "padding" queries tokens which FA has some issues with. In
|
|
355
|
-
# some models like Qwen3-4B-Instruct-2507, if we don't include these tokens in cumulative_seqlens_q, there are
|
|
356
|
-
# some NaNs in the output logits even for non-padded tokens.
|
|
357
|
-
if use_padding:
|
|
358
|
-
self.max_seqlen_q = max(self.max_seqlen_q, q_len - self.total_seqlen_q)
|
|
359
|
-
self.cumulative_seqlens_q[self.actual_batch_size + 1 :] = q_len
|
|
360
|
-
# FIXME: is there another way to avoid this? It has a very slight impact on performance (~5 tok/s)
|
|
361
|
-
|
|
362
|
-
# For the attributes that are lists of tensors, we construct list of tensor references
|
|
363
|
-
for i, (read_index_size, write_index_size) in enumerate(self.actual_index_sizes):
|
|
364
|
-
read_index_size = padded_kv_size if use_padding else read_index_size
|
|
365
|
-
write_index_size = padded_q_size if use_padding else write_index_size
|
|
366
|
-
kwargs["read_index"].append(self.read_index_storage[i][:read_index_size])
|
|
367
|
-
kwargs["write_index"].append(self.write_index_storage[i][:write_index_size])
|
|
368
|
-
|
|
369
|
-
# For the attributes that are dict of tensors, we replace the dict with a tensor if there is only one entry
|
|
370
|
-
layer_types = list(self.cumulative_seqlens_k.keys())
|
|
371
|
-
if len(layer_types) > 1:
|
|
372
|
-
for layer_type, seqlens_k in self.cumulative_seqlens_k.items():
|
|
373
|
-
kwargs["cu_seq_lens_k"][layer_type] = seqlens_k[: b_size + 1]
|
|
374
|
-
kwargs["max_seqlen_k"][layer_type] = self.max_seqlen_k[layer_type]
|
|
375
|
-
if self.attention_mask is not None:
|
|
376
|
-
k_len = padded_kv_size if use_padding else seqlens_k[b_size]
|
|
377
|
-
kwargs["attention_mask"][layer_type] = self.attention_mask[layer_type][..., :q_len, :k_len]
|
|
378
|
-
else:
|
|
379
|
-
layer_type = layer_types[0]
|
|
380
|
-
kwargs["cu_seq_lens_k"] = self.cumulative_seqlens_k[layer_type][: b_size + 1]
|
|
381
|
-
kwargs["max_seqlen_k"] = self.max_seqlen_k[layer_type]
|
|
382
|
-
if self.attention_mask is not None:
|
|
383
|
-
k_len = padded_kv_size if use_padding else self.cumulative_seqlens_k[layer_type][b_size]
|
|
384
|
-
kwargs["attention_mask"] = self.attention_mask[layer_type][..., :q_len, :k_len]
|
|
385
|
-
|
|
386
|
-
if self.attention_mask is None:
|
|
387
|
-
kwargs["attention_mask"] = None
|
|
388
|
-
return kwargs
|
|
150
|
+
# Setup inputs and outputs
|
|
151
|
+
self.inputs_and_outputs = ContinuousBatchingIOs(cache, config, model_device, model_dtype)
|
|
389
152
|
|
|
390
153
|
def __repr__(self) -> str:
|
|
391
154
|
return (
|
|
392
155
|
f"ContinuousBatchProcessor(input_queue={self.input_queue}, output_queue={self.output_queue}, "
|
|
393
156
|
f"active_requests={self.scheduler.active_requests}, waiting_requests={self.scheduler.waiting_requests})"
|
|
394
|
-
+ self.get_model_kwargs().__repr__()
|
|
157
|
+
+ self.inputs_and_outputs.get_model_kwargs().__repr__()
|
|
395
158
|
)
|
|
396
159
|
|
|
397
160
|
@traced
|
|
@@ -408,7 +171,7 @@ class ContinuousBatchProcessor:
|
|
|
408
171
|
break
|
|
409
172
|
except Exception as e:
|
|
410
173
|
logger.error(f"Error processing new request: {e}", exc_info=True)
|
|
411
|
-
state: RequestState = locals().get("state")
|
|
174
|
+
state: RequestState = locals().get("state") # type:ignore
|
|
412
175
|
if state is not None:
|
|
413
176
|
self._handle_request_error(e, state)
|
|
414
177
|
|
|
@@ -467,83 +230,30 @@ class ContinuousBatchProcessor:
|
|
|
467
230
|
self.metrics.record_queue_metrics(len(self.scheduler.active_requests), len(self.scheduler.waiting_requests))
|
|
468
231
|
|
|
469
232
|
# Schedule the next batch of requests, stop if there are no requests in the batch
|
|
470
|
-
|
|
233
|
+
requests_in_batch = self.scheduler.schedule_batch(self.max_batch_tokens, self.cache.num_pages)
|
|
471
234
|
|
|
472
235
|
# If requests_in_batch is None, it means we need to offload some requests if possible
|
|
473
|
-
if
|
|
236
|
+
if requests_in_batch is None:
|
|
474
237
|
if len(self.scheduler.active_requests) > 1:
|
|
475
238
|
self.soft_reset_one_request()
|
|
239
|
+
return False
|
|
476
240
|
else:
|
|
477
241
|
raise RuntimeError("No requests can be scheduled and no request can be offloaded.")
|
|
478
242
|
# If it's an empty list, it means we have no requests to process
|
|
243
|
+
self.requests_in_batch = requests_in_batch
|
|
479
244
|
if not self.requests_in_batch:
|
|
480
245
|
return False
|
|
246
|
+
|
|
481
247
|
# Otherwise, we can continue with the non-empty batch
|
|
482
248
|
self.metrics.record_batch_metrics(self.requests_in_batch)
|
|
249
|
+
self.inputs_and_outputs.prepare_batch_tensors(requests_in_batch)
|
|
483
250
|
|
|
484
|
-
#
|
|
485
|
-
self.reset_static_tensors() # FIXME: why does this make the generation faster?
|
|
486
|
-
|
|
487
|
-
# Prepare accumulators
|
|
488
|
-
self.actual_query_length = 0
|
|
489
|
-
self.actual_key_length = 0
|
|
490
|
-
self.actual_batch_size = 0
|
|
491
|
-
|
|
492
|
-
input_ids = []
|
|
493
|
-
position_ids = []
|
|
494
|
-
cumulative_seqlens_q = [0]
|
|
495
|
-
logits_indices = []
|
|
496
|
-
|
|
497
|
-
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
|
|
498
|
-
|
|
499
|
-
read_index = [[] for _ in range(self.cache.num_groups)]
|
|
500
|
-
write_index = [[] for _ in range(self.cache.num_groups)]
|
|
501
|
-
|
|
502
|
-
# Go through all the requests in the batch
|
|
503
|
-
for state in self.requests_in_batch:
|
|
504
|
-
# First we retrieve the lengths related to the request
|
|
505
|
-
past_length = state.position_offset
|
|
506
|
-
query_length = len(state.tokens_to_process)
|
|
507
|
-
seqlens_k = self.cache.get_seqlens_k(state.request_id, past_length, query_length)
|
|
508
|
-
|
|
509
|
-
# Then we update the total lengths that are used for slicing
|
|
510
|
-
self.actual_query_length += query_length
|
|
511
|
-
# total_key_length is used to slice the keys so we need to take the max of all the key lengths
|
|
512
|
-
self.actual_key_length += max(seqlens_k.values())
|
|
513
|
-
self.actual_batch_size += 1
|
|
514
|
-
# And the attribute tracking the position in the request object
|
|
515
|
-
state.position_offset += query_length
|
|
516
|
-
|
|
517
|
-
# Then we accumulate for the object used in the kwargs
|
|
518
|
-
input_ids.extend(state.tokens_to_process)
|
|
519
|
-
position_ids.extend(range(past_length, past_length + query_length))
|
|
520
|
-
cumulative_seqlens_q.append(cumulative_seqlens_q[-1] + query_length)
|
|
521
|
-
self.max_seqlen_q = max(self.max_seqlen_q, query_length)
|
|
522
|
-
|
|
523
|
-
if not state.remaining_prefill_tokens:
|
|
524
|
-
logits_indices.append(cumulative_seqlens_q[-1] - 1)
|
|
525
|
-
|
|
526
|
-
for layer_type, layer_type_seqlen_k in seqlens_k.items():
|
|
527
|
-
cumulative_seqlens_k[layer_type].append(cumulative_seqlens_k[layer_type][-1] + layer_type_seqlen_k)
|
|
528
|
-
self.max_seqlen_k[layer_type] = max(self.max_seqlen_k[layer_type], layer_type_seqlen_k)
|
|
529
|
-
|
|
530
|
-
self.cache.extend_read_indices(state.request_id, past_length, query_length, read_index)
|
|
531
|
-
self.cache.extend_write_indices(state.request_id, past_length, query_length, write_index)
|
|
532
|
-
|
|
533
|
-
# When looping over request is done, we can build the actual tensors
|
|
534
|
-
self._build_tensors(
|
|
535
|
-
input_ids,
|
|
536
|
-
position_ids,
|
|
537
|
-
read_index,
|
|
538
|
-
write_index,
|
|
539
|
-
cumulative_seqlens_q,
|
|
540
|
-
cumulative_seqlens_k,
|
|
541
|
-
logits_indices,
|
|
542
|
-
)
|
|
251
|
+
# Record the memory metrics of the KV cache
|
|
543
252
|
self.metrics.record_kv_cache_memory_metrics(self.cache)
|
|
544
|
-
|
|
545
253
|
if logger.isEnabledFor(logging.DEBUG):
|
|
546
|
-
|
|
254
|
+
cumulative_seqlens_q = self.inputs_and_outputs.cumulative_seqlens_q
|
|
255
|
+
cumulative_seqlens_k = self.inputs_and_outputs.cumulative_seqlens_k
|
|
256
|
+
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in cumulative_seqlens_k)
|
|
547
257
|
logger.debug(
|
|
548
258
|
f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, "
|
|
549
259
|
f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. "
|
|
@@ -551,52 +261,6 @@ class ContinuousBatchProcessor:
|
|
|
551
261
|
)
|
|
552
262
|
return True
|
|
553
263
|
|
|
554
|
-
@traced
|
|
555
|
-
def _build_tensors(
|
|
556
|
-
self,
|
|
557
|
-
input_ids: list[int],
|
|
558
|
-
position_ids: list[int],
|
|
559
|
-
read_index: list[list[int]],
|
|
560
|
-
write_index: list[list[int]],
|
|
561
|
-
cumulative_seqlens_q: list[int],
|
|
562
|
-
cumulative_seqlens_k: dict[str, list[int]],
|
|
563
|
-
logits_indices: list[int],
|
|
564
|
-
) -> None:
|
|
565
|
-
"""Builds the actual tensors for the current batch, by modifying the already allocated tensors in place."""
|
|
566
|
-
to_tensor = partial(torch.tensor, **self.tensor_metadata)
|
|
567
|
-
|
|
568
|
-
# Those kwargs always have the same type regardless of the model
|
|
569
|
-
self.input_ids[:, : len(input_ids)] = to_tensor(input_ids)
|
|
570
|
-
self.position_ids[:, : len(position_ids)] = to_tensor(position_ids)
|
|
571
|
-
self.cumulative_seqlens_q[: len(cumulative_seqlens_q)] = to_tensor(cumulative_seqlens_q)
|
|
572
|
-
self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices)
|
|
573
|
-
self.total_seqlen_q = cumulative_seqlens_q[-1]
|
|
574
|
-
|
|
575
|
-
# Those kwargs are either dict of tensors or tensors, so we need to handle both cases
|
|
576
|
-
for layer_type, layer_type_seqlens_k in cumulative_seqlens_k.items():
|
|
577
|
-
self.cumulative_seqlens_k[layer_type][: len(layer_type_seqlens_k)] = to_tensor(layer_type_seqlens_k)
|
|
578
|
-
if self.attention_mask is not None:
|
|
579
|
-
build_attention_mask(
|
|
580
|
-
attention_mask=self.attention_mask[layer_type],
|
|
581
|
-
cumulative_seqlens_q=cumulative_seqlens_q,
|
|
582
|
-
cumulative_seqlens_k=layer_type_seqlens_k,
|
|
583
|
-
sliding_window=self.sliding_window if layer_type == "sliding_attention" else 1,
|
|
584
|
-
)
|
|
585
|
-
|
|
586
|
-
# The index only contain references to the storage tensors, so we update the storage and their references
|
|
587
|
-
self.read_index = []
|
|
588
|
-
self.write_index = []
|
|
589
|
-
for i, group_read_indices, group_write_indices in zip(count(), read_index, write_index):
|
|
590
|
-
self.read_index_storage[i][: len(group_read_indices)] = to_tensor(group_read_indices)
|
|
591
|
-
self.write_index_storage[i][: len(group_write_indices)] = to_tensor(group_write_indices)
|
|
592
|
-
self.actual_index_sizes[i] = (len(group_read_indices), len(group_write_indices))
|
|
593
|
-
|
|
594
|
-
@traced
|
|
595
|
-
def _get_new_tokens(self, num_new_tokens: int) -> list[int]:
|
|
596
|
-
indices = self.logits_indices[:num_new_tokens]
|
|
597
|
-
new_tokens = self.output_ids[indices]
|
|
598
|
-
return new_tokens.tolist()
|
|
599
|
-
|
|
600
264
|
@traced
|
|
601
265
|
def _maybe_send_output(self, state: RequestState) -> None:
|
|
602
266
|
"""Send output to the queue based on streaming mode and request state."""
|
|
@@ -606,13 +270,13 @@ class ContinuousBatchProcessor:
|
|
|
606
270
|
@traced
|
|
607
271
|
def update_batch(self) -> None:
|
|
608
272
|
"""Update request states based on generated tokens."""
|
|
609
|
-
new_tokens = self.
|
|
273
|
+
new_tokens = self.inputs_and_outputs.output_ids[: len(self.requests_in_batch)].tolist()
|
|
610
274
|
current_logits_index = 0
|
|
611
275
|
for state in self.requests_in_batch:
|
|
612
276
|
# If the request has no remaining prompt ids, it means prefill has already ended or just finished
|
|
613
277
|
if len(state.remaining_prefill_tokens) == 0:
|
|
614
|
-
# If there
|
|
615
|
-
if state.generated_len() ==
|
|
278
|
+
# If there is just one temporary token, it means prefill just ended
|
|
279
|
+
if state.generated_len() == 1:
|
|
616
280
|
self.metrics.record_ttft_metric(state.created_time, state.request_id)
|
|
617
281
|
state.status = RequestStatus.DECODING
|
|
618
282
|
|
|
@@ -640,15 +304,15 @@ class ContinuousBatchProcessor:
|
|
|
640
304
|
copy_source, copy_destination = [], []
|
|
641
305
|
while self.scheduler._requests_to_fork:
|
|
642
306
|
# Get the number of children and reset it so it's not forked again
|
|
643
|
-
|
|
644
|
-
num_children =
|
|
645
|
-
|
|
307
|
+
state_to_fork = self.scheduler._requests_to_fork.pop()
|
|
308
|
+
num_children = state_to_fork.num_children
|
|
309
|
+
state_to_fork.num_children = 0
|
|
646
310
|
# Create the new request and add them to the scheduler
|
|
647
|
-
new_request_ids = [f"{
|
|
311
|
+
new_request_ids = [f"{state_to_fork.request_id}__child#{i}" for i in range(num_children)]
|
|
648
312
|
for new_request_id in new_request_ids:
|
|
649
|
-
self.scheduler.active_requests[new_request_id] =
|
|
313
|
+
self.scheduler.active_requests[new_request_id] = state_to_fork.fork(new_request_id)
|
|
650
314
|
# Fork the cache
|
|
651
|
-
copy_src, copy_dst = self.cache.fork_request(
|
|
315
|
+
copy_src, copy_dst = self.cache.fork_request(state_to_fork.request_id, new_request_ids)
|
|
652
316
|
copy_source.extend(copy_src)
|
|
653
317
|
copy_destination.extend(copy_dst)
|
|
654
318
|
# FIXME: if fork cant be done, create a new pending request without forking instead of crashing everything
|
|
@@ -692,8 +356,8 @@ class ContinuousBatchProcessor:
|
|
|
692
356
|
self.scheduler.waiting_requests_order.clear()
|
|
693
357
|
|
|
694
358
|
@traced
|
|
695
|
-
@torch.no_grad
|
|
696
|
-
def _generation_step(self, model: nn.Module, logit_processor:
|
|
359
|
+
@torch.no_grad()
|
|
360
|
+
def _generation_step(self, model: nn.Module, logit_processor: LogitsProcessorList, do_sample: bool) -> None:
|
|
697
361
|
"""Perform a single generation step."""
|
|
698
362
|
|
|
699
363
|
# If a compile config is specified, we compile the forward pass once in a wrapper
|
|
@@ -710,14 +374,18 @@ class ContinuousBatchProcessor:
|
|
|
710
374
|
|
|
711
375
|
# If inputs are static sized, we find the padded sizes of the queries and keys/values
|
|
712
376
|
if self._pad_inputs:
|
|
713
|
-
|
|
714
|
-
|
|
377
|
+
actual_query_length = self.inputs_and_outputs.actual_query_length
|
|
378
|
+
actual_index_sizes = self.inputs_and_outputs.actual_index_sizes
|
|
379
|
+
padded_q = pad_by_intervals(actual_query_length, self.max_batch_tokens, self.q_padding_intervals)
|
|
380
|
+
max_read_index_size = max(actual_index_sizes[i][0] for i in range(self.cache.num_groups))
|
|
715
381
|
# The space planned for query tokens will be added later, so we remove it from the space planned for KV
|
|
716
|
-
padded_read_index_size = pad_by_intervals(
|
|
382
|
+
padded_read_index_size = pad_by_intervals(
|
|
383
|
+
max_read_index_size, self.cache.num_pages, self.kv_padding_intervals
|
|
384
|
+
)
|
|
717
385
|
else:
|
|
718
386
|
padded_q, padded_read_index_size = 0, 0
|
|
719
387
|
# Retrieve the model kwargs with or without padding
|
|
720
|
-
batch_data = self.get_model_kwargs(padded_q, padded_read_index_size)
|
|
388
|
+
batch_data = self.inputs_and_outputs.get_model_kwargs(padded_q, padded_read_index_size)
|
|
721
389
|
|
|
722
390
|
# If we are not using cuda graphs, we perform the generation step and return
|
|
723
391
|
if self._graphs is None:
|
|
@@ -746,21 +414,23 @@ class ContinuousBatchProcessor:
|
|
|
746
414
|
|
|
747
415
|
@traced
|
|
748
416
|
def _forward_process_and_sample(
|
|
749
|
-
self, model: nn.Module, batch_data: dict, logit_processor:
|
|
417
|
+
self, model: nn.Module, batch_data: dict, logit_processor: LogitsProcessorList, do_sample: bool
|
|
750
418
|
) -> None:
|
|
751
419
|
"""This function performs the forward pass, logits processing, and sampling; which are broken down into smaller
|
|
752
420
|
function to be easier to trace with OpenTelemetry."""
|
|
753
421
|
logits = self._model_forward(model, batch_data)
|
|
754
422
|
# if self.log_prob_generation: batch_processor.output_probs.copy_(logits) # TODO
|
|
755
423
|
probs = self._process_logit(batch_data, logits, logit_processor)
|
|
756
|
-
self._sample(probs, do_sample)
|
|
424
|
+
self._sample(probs, batch_data, do_sample)
|
|
757
425
|
|
|
758
426
|
@traced(span_name="model_forward")
|
|
759
427
|
def _model_forward(self, model: nn.Module, batch_data: dict) -> torch.Tensor:
|
|
760
428
|
return model(**batch_data).logits
|
|
761
429
|
|
|
762
430
|
@traced(span_name="logit_processing")
|
|
763
|
-
def _process_logit(
|
|
431
|
+
def _process_logit(
|
|
432
|
+
self, batch_data: dict, logits: torch.Tensor, logit_processor: LogitsProcessorList
|
|
433
|
+
) -> torch.Tensor:
|
|
764
434
|
# Pass continuous batching context to logits processor if it supports it.
|
|
765
435
|
if hasattr(logit_processor, "set_continuous_batching_context"):
|
|
766
436
|
logit_processor.set_continuous_batching_context(batch_data["logits_indices"], batch_data["cu_seq_lens_q"])
|
|
@@ -770,13 +440,13 @@ class ContinuousBatchProcessor:
|
|
|
770
440
|
# NOTE: to be an exact match with generate, we should also convert logits2d to float32 here, but it's not needed in practice
|
|
771
441
|
logits_2d = logits.view(batch_size * seq_len, vocab_size)
|
|
772
442
|
input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)
|
|
773
|
-
# Process with 2D tensors
|
|
774
|
-
processed_logits_2d = logit_processor(input_ids_2d, logits_2d)
|
|
443
|
+
# Process with 2D tensors#
|
|
444
|
+
processed_logits_2d = logit_processor(input_ids_2d, logits_2d) # type: ignore[arg-type]
|
|
775
445
|
# Reshape back to 3D
|
|
776
446
|
return processed_logits_2d.view(batch_size, seq_len, vocab_size)
|
|
777
447
|
|
|
778
448
|
@traced(span_name="sampling")
|
|
779
|
-
def _sample(self, probs: torch.Tensor, do_sample: bool) -> None:
|
|
449
|
+
def _sample(self, probs: torch.Tensor, batch_data: dict, do_sample: bool) -> None:
|
|
780
450
|
if do_sample:
|
|
781
451
|
probs = nn.functional.softmax(probs, dim=-1)
|
|
782
452
|
# probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
|
|
@@ -785,7 +455,10 @@ class ContinuousBatchProcessor:
|
|
|
785
455
|
next_tokens = torch.argmax(probs, dim=-1) # shape is [1, seq_len]
|
|
786
456
|
next_tokens = next_tokens.squeeze(0) # shape is [seq_len]
|
|
787
457
|
tokens = next_tokens.size(0) # Get seq_len dimension
|
|
788
|
-
|
|
458
|
+
#
|
|
459
|
+
indices = batch_data["logits_indices"][:tokens]
|
|
460
|
+
next_tokens = next_tokens[indices]
|
|
461
|
+
self.inputs_and_outputs.output_ids[:tokens].copy_(next_tokens)
|
|
789
462
|
|
|
790
463
|
|
|
791
464
|
# Manager Class (User Interface)
|
|
@@ -799,7 +472,7 @@ class ContinuousBatchingManager:
|
|
|
799
472
|
|
|
800
473
|
def __init__(
|
|
801
474
|
self,
|
|
802
|
-
model:
|
|
475
|
+
model: ProtoPretrainedModel,
|
|
803
476
|
generation_config: GenerationConfig,
|
|
804
477
|
manual_eviction: bool = False,
|
|
805
478
|
max_queue_size: int = 0,
|
|
@@ -840,7 +513,7 @@ class ContinuousBatchingManager:
|
|
|
840
513
|
self.generation_config = generation_config
|
|
841
514
|
self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
|
|
842
515
|
self.do_sample = getattr(generation_config, "do_sample", True)
|
|
843
|
-
self.logit_processor = self.model._get_logits_processor(generation_config)
|
|
516
|
+
self.logit_processor: LogitsProcessorList = self.model._get_logits_processor(generation_config)
|
|
844
517
|
num_return_sequences = getattr(generation_config, "num_return_sequences", None)
|
|
845
518
|
self.num_return_sequences = num_return_sequences if num_return_sequences is not None else 1
|
|
846
519
|
|
|
@@ -879,6 +552,11 @@ class ContinuousBatchingManager:
|
|
|
879
552
|
If none of the above criteria are met, we use a default heuristic based on the attention implementation: we turn
|
|
880
553
|
on cuda graphs if and only if no attention mask is needed.
|
|
881
554
|
"""
|
|
555
|
+
# If cuda is not available, we cannot use cuda graphs
|
|
556
|
+
if not torch.cuda.is_available():
|
|
557
|
+
if use_cuda_graph:
|
|
558
|
+
logger.warning(f"use_cuda_graph is True but {torch.cuda.is_available() = }: turning off cuda graphs.")
|
|
559
|
+
return False
|
|
882
560
|
# If use_cuda_graph is specified, we follow the user's choice
|
|
883
561
|
if use_cuda_graph is not None:
|
|
884
562
|
return use_cuda_graph
|
|
@@ -902,7 +580,7 @@ class ContinuousBatchingManager:
|
|
|
902
580
|
logger.warning(
|
|
903
581
|
f"No behavior specified for use_cuda_graph, defaulting to {use_cuda_graph = } because "
|
|
904
582
|
f"{self.model.config._attn_implementation = }. If you want to save memory, turn off cuda graphs, but "
|
|
905
|
-
"they
|
|
583
|
+
"they tend to improve performances by a lot."
|
|
906
584
|
)
|
|
907
585
|
return use_cuda_graph
|
|
908
586
|
|
|
@@ -1013,14 +691,17 @@ class ContinuousBatchingManager:
|
|
|
1013
691
|
streaming: bool = False,
|
|
1014
692
|
record_timestamps: bool = False,
|
|
1015
693
|
) -> None:
|
|
1016
|
-
#
|
|
694
|
+
# Infer the request ids of all incoming requests
|
|
695
|
+
with self._request_lock:
|
|
696
|
+
request_ids = [f"req_{i}" for i in range(self._request_counter, self._request_counter + len(inputs))]
|
|
697
|
+
self._request_counter += len(inputs)
|
|
698
|
+
# If there is prefix sharing, we sort the inputs to maximize cache hits but keep the order of the requests
|
|
699
|
+
ids_and_inputs = list(zip(request_ids, inputs))
|
|
1017
700
|
if self._use_prefix_sharing:
|
|
1018
|
-
|
|
701
|
+
ids_and_inputs = sorted(ids_and_inputs, key=lambda x: x[1], reverse=True)
|
|
1019
702
|
# Add requests in order
|
|
1020
|
-
for input_ids in
|
|
1021
|
-
self.add_request(
|
|
1022
|
-
input_ids, max_new_tokens=max_new_tokens, streaming=streaming, record_timestamps=record_timestamps
|
|
1023
|
-
)
|
|
703
|
+
for request_id, input_ids in ids_and_inputs:
|
|
704
|
+
self.add_request(input_ids, request_id, max_new_tokens, streaming, record_timestamps)
|
|
1024
705
|
|
|
1025
706
|
def cancel_request(self, request_id: str) -> None:
|
|
1026
707
|
"""Cancel a request by its ID.
|
|
@@ -1073,7 +754,9 @@ class ContinuousBatchingManager:
|
|
|
1073
754
|
|
|
1074
755
|
@traced
|
|
1075
756
|
def _generation_step(self) -> None:
|
|
1076
|
-
"""Perform a single generation step. This is cuda graphed"""
|
|
757
|
+
"""Perform a single generation step. This is mostly cuda graphed"""
|
|
758
|
+
if self.batch_processor is None:
|
|
759
|
+
raise RuntimeError("Tried to perform a generation step before the batch processor was initialized.")
|
|
1077
760
|
self.batch_processor._generation_step(self.model, self.logit_processor, self.do_sample)
|
|
1078
761
|
|
|
1079
762
|
def _run_generation_loop(self) -> None:
|
|
@@ -1139,14 +822,6 @@ class ContinuousBatchingManager:
|
|
|
1139
822
|
# Loop body ends if there is no requests in the batch
|
|
1140
823
|
if not batch_processor.prepare_next_batch():
|
|
1141
824
|
return
|
|
1142
|
-
# Debug logging of the current memory usage -- commented out because it's often not used even in debug
|
|
1143
|
-
# if logger.level < logging.DEBUG:
|
|
1144
|
-
# device, total, reserved, allocated = get_device_and_memory_breakdown()
|
|
1145
|
-
# available_memory = total - max(allocated, reserved)
|
|
1146
|
-
# logger.debug(
|
|
1147
|
-
# f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}, Available: {available_memory}"
|
|
1148
|
-
# )
|
|
1149
|
-
|
|
1150
825
|
self._generation_step()
|
|
1151
826
|
batch_processor.update_batch()
|
|
1152
827
|
|
|
@@ -1181,6 +856,8 @@ class ContinuousBatchingManager:
|
|
|
1181
856
|
class ContinuousMixin:
|
|
1182
857
|
"""Mixin class for models to add continuous batching capabilities."""
|
|
1183
858
|
|
|
859
|
+
generation_config: GenerationConfig
|
|
860
|
+
|
|
1184
861
|
@contextmanager
|
|
1185
862
|
def continuous_batching_context_manager(
|
|
1186
863
|
self,
|
|
@@ -1246,7 +923,7 @@ class ContinuousMixin:
|
|
|
1246
923
|
|
|
1247
924
|
# Create and return the manager
|
|
1248
925
|
return ContinuousBatchingManager(
|
|
1249
|
-
model=self,
|
|
926
|
+
model=self, # type: ignore
|
|
1250
927
|
generation_config=gen_config,
|
|
1251
928
|
manual_eviction=manual_eviction,
|
|
1252
929
|
max_queue_size=max_queue_size,
|
|
@@ -1328,8 +1005,19 @@ class ContinuousMixin:
|
|
|
1328
1005
|
else:
|
|
1329
1006
|
if not manager.is_running():
|
|
1330
1007
|
logger.error("Generation thread terminated unexpectedly.")
|
|
1008
|
+
# This helps get some information in stdout
|
|
1009
|
+
print("Returning results of generate_batch despite unexpected termination.")
|
|
1331
1010
|
break
|
|
1332
1011
|
|
|
1333
1012
|
except Exception as e:
|
|
1334
1013
|
logger.error(f"Error during batch generation: {e}", exc_info=True)
|
|
1335
|
-
|
|
1014
|
+
# Re-order requests to match the order of the inputs
|
|
1015
|
+
reordered_results = {}
|
|
1016
|
+
for i in range(len(inputs)):
|
|
1017
|
+
# We cannot guarantee that the generation succeeded for all requests, so we need to check if the request is in the results
|
|
1018
|
+
result = results.get(f"req_{i}")
|
|
1019
|
+
if result is not None:
|
|
1020
|
+
reordered_results[f"req_{i}"] = result
|
|
1021
|
+
else:
|
|
1022
|
+
logger.error(f"Request req_{i} not found in results.")
|
|
1023
|
+
return reordered_results
|