transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/generation/continuous_batching/continuous_api.py

@@ -29,7 +29,7 @@ from tqdm import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
 
 from ...configuration_utils import PretrainedConfig
-from ...generation.configuration_utils import GenerationConfig
+from ...generation.configuration_utils import CompileConfig, GenerationConfig
 from ...generation.logits_process import LogitsProcessor
 from ...utils.logging import logging
 from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
@@ -45,17 +45,20 @@ generation goes on, there are two dimensions that change:
 - the number of keys/values tokens (KV), which grows as the cache does
 
 To solve this, we slice along those dimensions to fixed lengths. The size of the slices is controlled by the variables
-
-number of queries tokens is 1000, and
-1000 / 4 = 250 tokens, ie. to 250, 500, 750 or 1000 queries tokens.
+num_x_padding_intervals: NUM_X_PADDING_INTERVALS means that we create at most NUM_X_PADDING_INTERVALS graphs for the X
+dimension. So if the maximum number of queries tokens is 1000, and NUM_Q_PADDING_INTERVALS is 4, we will slice the
+number of queries token by intervals of 1000 / 4 = 250 tokens, ie. to 250, 500, 750 or 1000 queries tokens.
 
 Smaller slices means more granularity and thus less padding. But since each graph takes up space on the GPU and time to
 create, we don't want to many graphs. And since the size of the KV dimension is the number of queries tokens plus the
 number of tokens cached, dimension of KV is usually much larger than the dimension of Q. So we have more granularity
 for the KV dimension than the query dimension.
+
+This variable used to be called NUM_X_CUDA_GRAPHS, but we renamed it to NUM_X_PADDING_INTERVALS because it is used for
+padding in the case of cuda graphs AND torch.compile.
 """
-
-
+NUM_Q_PADDING_INTERVALS = 4
+NUM_KV_PADDING_INTERVALS = 8
 
 
 def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
@@ -63,7 +66,7 @@ def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
     interval_size = max_value // nb_intervals
     if interval_size == 0:
        return max_value
-    padded = ceil(size / interval_size) * interval_size
+    padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
     return min(padded, max_value)
 
 
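
The hunk above is the whole behavioural change to pad_by_intervals: a size of 0 now pads up to one full interval instead of collapsing to 0. Below is a minimal, self-contained sketch of the padding scheme described in the module docstring; the function body is copied from the diff, while the call values (1000 tokens, 4 intervals) simply mirror the docstring's example and are illustrative.

    # Interval padding as described in the continuous-batching docstring.
    # The function body is taken from the diff; the concrete values are illustrative.
    from math import ceil

    def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
        interval_size = max_value // nb_intervals
        if interval_size == 0:
            return max_value
        padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
        return min(padded, max_value)

    # With max_value=1000 and nb_intervals=4 (NUM_Q_PADDING_INTERVALS), sizes snap to
    # 250/500/750/1000, so at most 4 distinct query shapes (and thus graphs) are created.
    assert pad_by_intervals(1, 1000, 4) == 250
    assert pad_by_intervals(251, 1000, 4) == 500
    assert pad_by_intervals(1000, 1000, 4) == 1000
    # The rc2 change: a size of 0 now pads to one interval instead of 0.
    assert pad_by_intervals(0, 1000, 4) == 250
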
@@ -188,6 +191,8 @@ class ContinuousBatchProcessor:
         scheduler: Scheduler,
         manual_eviction: bool,
         use_cuda_graph: bool,
+        q_padding_intervals: int,
+        kv_padding_intervals: int,
     ) -> None:
         """Initialize the continuous batch processor.
 
@@ -221,7 +226,14 @@ class ContinuousBatchProcessor:
         # Accumulator for batch scheduling
         self.requests_in_batch: list[RequestState] = []
         # Cuda graphs for the generation step
+        self.q_padding_intervals = q_padding_intervals
+        self.kv_padding_intervals = kv_padding_intervals
         self._graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] | None = {} if use_cuda_graph else None
+        # Compile-related arguments
+        self.compile_config: CompileConfig | None = getattr(generation_config, "compile_config", None)
+        self._forward_process_and_sample_is_compiled = False
+
+        self._pad_inputs = use_cuda_graph or (self.compile_config is not None and not self.compile_config.dynamic)
 
         # Set up metrics collector
         self.max_batch_tokens = cache.max_batch_tokens
@@ -247,7 +259,7 @@ class ContinuousBatchProcessor:
         self.cumulative_seqlens_q = torch.empty((self.max_batch_tokens + 1,), **self.tensor_metadata)
         self.max_seqlen_q = 0
         self.logits_indices = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)
-        self.output_ids = torch.empty((
+        self.output_ids = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)
 
         # For some kwargs, we have a dict of tensors with as many items as there are attention types
         layer_types = getattr(self.config, "layer_types", None)
@@ -299,7 +311,7 @@ class ContinuousBatchProcessor:
         self.cumulative_seqlens_q[: b_size + 1].zero_()
         self.max_seqlen_q = 0
         self.logits_indices[:q_len].fill_(-1)
-        self.output_ids[
+        self.output_ids[:q_len].fill_(-1)
 
         # Reset the attributes that are either tensors or dict of tensors
         for layer_type in self.cumulative_seqlens_k:
@@ -435,7 +447,7 @@ class ContinuousBatchProcessor:
         self.metrics.record_batch_metrics(self.requests_in_batch)
 
         # Reset the static tensors used for storage
-        self.reset_static_tensors()  #
+        self.reset_static_tensors()  # FIXME: why does this make the generation faster?
 
         # Prepare accumulators
         self.actual_query_length = 0
@@ -545,13 +557,10 @@ class ContinuousBatchProcessor:
             self.actual_index_sizes[i] = (len(group_read_indices), len(group_write_indices))
 
     @traced
-    def
-
-
-
-    except Exception:
-        return [0, 1]
-    return [0, 0]
+    def _get_new_tokens(self, num_new_tokens: int) -> list[int]:
+        indices = self.logits_indices[:num_new_tokens]
+        new_tokens = self.output_ids[indices]
+        return new_tokens.tolist()
 
     @traced
     def _maybe_send_output(self, state: RequestState) -> None:
@@ -562,29 +571,56 @@ class ContinuousBatchProcessor:
     @traced
     def update_batch(self) -> None:
         """Update request states based on generated tokens."""
-
-
+        new_tokens = self._get_new_tokens(len(self.requests_in_batch))
+        current_logits_index = 0
+        for state in self.requests_in_batch:
             # If the request has no remaining prompt ids, it means prefill has already ended or just finished
             if len(state.remaining_prefill_tokens) == 0:
-
-                state.
-
+                # If there are no generated tokens yet, it means prefill just ended
+                if state.generated_len() == 0:
+                    self.metrics.record_ttft_metric(state.created_time, state.request_id)
+                    state.status = RequestStatus.DECODING
+
+                token = new_tokens[current_logits_index]
                 state.tokens_to_process = [token]
+                current_logits_index += 1
+
                 # Update the request and stop if it is complete
                 is_finished = state.update_and_check_completion(token)
                 # We mark the completed blocks as such
-                self.cache.
+                self.cache.mark_shareable_blocks_as_complete(state)
                 if is_finished:
                     self.metrics.record_request_completion(state.created_time, state.request_id)
                     self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
                     self._maybe_send_output(state)
             # Otherwise, the request is still prefilling, but the prefill has been split
             elif state.status == RequestStatus.PREFILLING_SPLIT:
-                self.cache.
+                self.cache.mark_shareable_blocks_as_complete(state)
                 state.status = RequestStatus.SPLIT_PENDING_REMAINDER
             else:
                 raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")

+        # If some requests need to be forked, we do it now
+        copy_source, copy_destination = [], []
+        while self.scheduler._requests_to_fork:
+            # Get the number of children and reset it so it's not forked again
+            state = self.scheduler._requests_to_fork.pop()
+            num_children = state.num_children
+            state.num_children = 0
+            # Create the new request and add them to the scheduler
+            new_request_ids = [f"{state.request_id}__child#{i}" for i in range(num_children)]
+            for new_request_id in new_request_ids:
+                self.scheduler.active_requests[new_request_id] = state.fork(new_request_id)
+            # Fork the cache
+            copy_src, copy_dst = self.cache.fork_request(state.request_id, new_request_ids)
+            copy_source.extend(copy_src)
+            copy_destination.extend(copy_dst)
+            # FIXME: if fork cant be done, create a new pending request without forking instead of crashing everything
+
+        # The copy induced by the fork is done in one go (if it's even needed)
+        if copy_source:
+            self.cache.copy_cache(copy_source, copy_destination)
+
         if self.cache.get_num_free_blocks() == 0:
             raise ValueError("No more free blocks")

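The forking block added at the end of `update_batch` supports `num_return_sequences > 1`: a request that has finished its prefill is cloned into `num_children` child requests, the children reuse the parent's already-computed prompt blocks, and all block copies implied by the forks are batched into a single `copy_cache` call. The sketch below mimics only that bookkeeping with a toy cache; the class and block-allocation scheme are illustrative stand-ins, not the library's `PagedAttentionCache`:

```python
class ToyBlockCache:
    """Toy stand-in: each request owns a list of block ids; forking gives children
    fresh block ids and records which source blocks must be copied into them."""

    def __init__(self) -> None:
        self.blocks: dict[str, list[int]] = {}
        self._next_block = 0

    def _new_block(self) -> int:
        self._next_block += 1
        return self._next_block - 1

    def allocate(self, request_id: str, num_blocks: int) -> None:
        self.blocks[request_id] = [self._new_block() for _ in range(num_blocks)]

    def fork_request(self, parent_id: str, child_ids: list[str]) -> tuple[list[int], list[int]]:
        src, dst = [], []
        for child_id in child_ids:
            child_blocks = [self._new_block() for _ in self.blocks[parent_id]]
            self.blocks[child_id] = child_blocks
            src.extend(self.blocks[parent_id])
            dst.extend(child_blocks)
        return src, dst


cache = ToyBlockCache()
cache.allocate("req-0", num_blocks=2)
copy_src, copy_dst = cache.fork_request("req-0", ["req-0__child#0", "req-0__child#1"])
# One batched copy at the end, mirroring self.cache.copy_cache(copy_source, copy_destination)
print(copy_src, copy_dst)  # [0, 1, 0, 1] [2, 3, 4, 5]
```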
@@ -627,28 +663,39 @@ class ContinuousBatchProcessor:
     def _generation_step(self, model: nn.Module, logit_processor: LogitsProcessor, do_sample: bool) -> None:
         """Perform a single generation step."""

-        # If
+        # If a compile config is specified, we compile the forward pass once in a wrapper
+        if self.compile_config is not None and not self._forward_process_and_sample_is_compiled:
+            self._forward_process_and_sample = torch.compile(
+                self._forward_process_and_sample,
+                fullgraph=self.compile_config.fullgraph,
+                mode=self.compile_config.mode,
+                dynamic=self.compile_config.dynamic,
+                backend=self.compile_config.backend,
+                options=self.compile_config.options,
+            )
+            self._forward_process_and_sample_is_compiled = True
+
+        # If inputs are static sized, we find the padded sizes of the queries and keys/values
+        if self._pad_inputs:
+            padded_q = pad_by_intervals(self.actual_query_length, self.max_batch_tokens, self.q_padding_intervals)
+            max_read_index_size = max(self.actual_index_sizes[i][0] for i in range(self.cache.num_groups))
+            padded_read_index_size = pad_by_intervals(
+                max_read_index_size - self.max_batch_tokens,
+                self.cache.num_blocks * self.cache.block_size,
+                self.kv_padding_intervals,
+            )
+        else:
+            padded_q, padded_read_index_size = 0, 0
+        # Retrieve the model kwargs with or without padding
+        batch_data = self.get_model_kwargs(padded_q, padded_read_index_size)
+
+        # If we are not using cuda graphs, we perform the generation step and return
         if self._graphs is None:
-            batch_data = self.get_model_kwargs()
             self._forward_process_and_sample(model, batch_data, logit_processor, do_sample)
             return None

-        # Determine the padded size of the queries and keys/values
-        padded_q = pad_by_intervals(self.actual_query_length, self.max_batch_tokens, NUM_Q_CUDA_GRAPHS)
-
-        max_read_index_size = max(self.actual_index_sizes[i][0] for i in range(self.cache.num_groups))
-        padded_read_index_size = pad_by_intervals(
-            max_read_index_size - self.max_batch_tokens,
-            self.cache.num_blocks * self.cache.block_size,
-            NUM_KV_CUDA_GRAPHS,
-        )
-
-        # Get the batch data and the associated graph
-        batch_data = self.get_model_kwargs(padded_q, padded_read_index_size)
-
-        graph = self._graphs.get((padded_q, padded_read_index_size))
-
         # If we have a graph that fits, we replay it
+        graph = self._graphs.get((padded_q, padded_read_index_size))
         if graph is not None:
             graph.replay()
             return None
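After this change, `_generation_step` always computes padded sizes and uses the pair `(padded_q, padded_read_index_size)` as the key into `self._graphs`; a matching graph is replayed, and otherwise (in code outside this hunk) a graph is presumably captured for that key. A minimal sketch of that capture-or-replay pattern under that assumption, with a placeholder forward callable and static input buffers:

```python
import torch

def graph_step(graphs: dict, static_batch: dict, key: tuple, forward) -> None:
    """Replay the CUDA graph captured for `key` if one exists; otherwise run eagerly
    and capture a graph for that key. Every tensor touched inside the graph must be
    the same static buffer on every step, which is what the padding above guarantees."""
    if not torch.cuda.is_available():
        forward(static_batch)  # eager fallback on machines without CUDA
        return
    graph = graphs.get(key)
    if graph is not None:
        graph.replay()
        return
    # Real capture code usually warms up on a side stream first; this is the bare pattern.
    forward(static_batch)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        forward(static_batch)
    graphs[key] = graph
```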
@@ -673,7 +720,6 @@ class ContinuousBatchProcessor:
     ) -> None:
         """This function performs the forward pass, logits processing, and sampling; which are broken down into smaller
         function to be easier to trace with OpenTelemetry."""
-        # with torch.no_grad():
         logits = self._model_forward(model, batch_data)
         # if self.log_prob_generation: batch_processor.output_probs.copy_(logits) # TODO
         probs = self._process_logit(batch_data, logits, logit_processor)
@@ -691,6 +737,7 @@ class ContinuousBatchProcessor:
         # Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size]
         # but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size]
         batch_size, seq_len, vocab_size = logits.shape
+        # NOTE: to be an exact match with generate, we should also convert logits2d to float32 here, but it's not needed in practice
         logits_2d = logits.view(batch_size * seq_len, vocab_size)
         input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)
         # Process with 2D tensors
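The NOTE added here sits in the shape-compatibility shim: continuous batching produces 3D logits, while logits processors operate on 2D `[batch, vocab]` tensors, so the code flattens, processes, and restores the shape. A small stand-alone illustration of that round trip (the temperature callable below is only a stand-in for a real logits processor with the usual `(input_ids, scores)` signature):

```python
import torch

def temperature_processor(input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
    # Stand-in processor: expects 2D scores of shape [batch_size, vocab_size]
    return scores / 0.7

batch_size, seq_len, vocab_size = 1, 5, 11
logits = torch.randn(batch_size, seq_len, vocab_size)  # 3D, as produced by continuous batching
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# Flatten to 2D, process, then restore the 3D shape
logits_2d = logits.view(batch_size * seq_len, vocab_size)
input_ids_2d = input_ids.view(batch_size * seq_len)
processed = temperature_processor(input_ids_2d, logits_2d).view(batch_size, seq_len, vocab_size)
print(processed.shape)  # torch.Size([1, 5, 11])
```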
@@ -704,12 +751,11 @@ class ContinuousBatchProcessor:
             probs = nn.functional.softmax(probs, dim=-1)
             # probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
             next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1) # Now [seq_len]
-            # Add batch dimension back to match argmax output
-            next_tokens = next_tokens.unsqueeze(0) # Now [1, seq_len]
         else:
-            next_tokens = torch.argmax(probs, dim=-1) #
-
-
+            next_tokens = torch.argmax(probs, dim=-1) # shape is [1, seq_len]
+            next_tokens = next_tokens.squeeze(0) # shape is [seq_len]
+        tokens = next_tokens.size(0) # Get seq_len dimension
+        self.output_ids[:tokens].copy_(next_tokens)


 # Manager Class (User Interface)
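Both sampling branches now end with a flat `[seq_len]` tensor of next tokens that is copied into the preallocated `output_ids` buffer, where `_get_new_tokens` later reads it back through `logits_indices`. A compact illustration of the two branches and the buffer write, following the shape comments in the diff:

```python
import torch

max_batch_tokens = 16
output_ids = torch.empty((max_batch_tokens,), dtype=torch.long)
probs = torch.softmax(torch.randn(1, 5, 11), dim=-1)  # [batch=1, seq_len, vocab]

do_sample = True
if do_sample:
    # probs[0] is [seq_len, vocab]; multinomial returns [seq_len, 1] -> squeeze to [seq_len]
    next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1)
else:
    next_tokens = torch.argmax(probs, dim=-1).squeeze(0)  # [1, seq_len] -> [seq_len]

tokens = next_tokens.size(0)
output_ids[:tokens].copy_(next_tokens)
print(output_ids[:tokens])
```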
@@ -727,9 +773,9 @@ class ContinuousBatchingManager:
         generation_config: GenerationConfig,
         manual_eviction: bool = False,
         max_queue_size: int = 0,
-
-
-
+        num_q_padding_intervals: int = 0,
+        num_kv_padding_intervals: int = 0,
+        allow_block_sharing: bool = True,
     ) -> None:
         """Initialize the continuous batching manager.

@@ -737,65 +783,98 @@ class ContinuousBatchingManager:
             model: The language model for generation
             generation_config: Configuration for generation parameters
             max_queue_size: Maximum size of the request queue (0 = unlimited)
-
-
-
+            num_q_padding_intervals: (optional) Number of intervals used to pad the query dimension
+            num_kv_padding_intervals: (optional) Number of intervals used to pad the keys/values dimension
+            allow_block_sharing: (optional) Whether to allow block sharing if the model has some full attention layers
         """
+        # Reload paged version of the attention implementation if necessary
         if "paged|" not in model.config._attn_implementation:
-
-            model.config._attn_implementation = attn_implementation
-
-            # lazy loading flash attention including kernel variations
-            if "flash" in attn_implementation:
-                from ...modeling_flash_attention_utils import lazy_import_paged_flash_attention
-
-                lazy_import_paged_flash_attention(attn_implementation)
+            model.set_attn_implementation(f"paged|{model.config._attn_implementation}")

+        # Internal arguments
         self.model = model.eval()
-
-        self.
+        self.manual_eviction = manual_eviction
+        self._allow_block_sharing = allow_block_sharing
+        self._use_prefix_sharing = allow_block_sharing  # approximation until the cache is created
+
         self.input_queue = queue.Queue(maxsize=max_queue_size)
         self.output_queue = queue.Queue()
         self.stop_event = threading.Event()
-        self.
+        self.batch_processor: ContinuousBatchProcessor | None = None
         self._generation_thread = None
         self._request_counter = 0
         self._request_lock = threading.Lock()
-
+
+        # Generation config related arguments
+        generation_config = model.generation_config if generation_config is None else generation_config
+        self.generation_config = generation_config
+        self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
         self.do_sample = getattr(generation_config, "do_sample", True)
         self.logit_processor = self.model._get_logits_processor(generation_config)
-
-        self.profile = getattr(generation_config, "profile", False) # TODO: not supported yet
-        self.manual_eviction = manual_eviction
-        self.batch_processor: ContinuousBatchProcessor | None = None
+        self.num_return_sequences = getattr(generation_config, "num_return_sequences", 1)

-        self.
+        # self.model.generation_config.top_p = None NOTE: figure out why this was here

-        #
-
-
-
-
-
-
-        else:
-            # Attention implementations where an attention mask is needed suffer a lot more from the padding associated
-            # with cuda graphs, so default is to turn cuda graphs off for those implementations
-            self.use_cuda_graph = not attn_mask_is_needed(self.model.config)
-            logger.warning(
-                f"No behavior specified for use_cuda_graph, defaulting to {self.use_cuda_graph = } because "
-                f"{self.model.config._attn_implementation = }. If you want to save memory, turn off cuda graphs, but "
-                "they can improve performances."
-            )
+        # Cuda graph behavior is determined below using either user-specified arguments or heuristics
+        self.use_cuda_graph = self._decide_use_cuda_graphs(
+            use_cuda_graph=getattr(generation_config, "use_cuda_graph", None),
+            num_q_padding_intervals=num_q_padding_intervals,
+            num_kv_padding_intervals=num_kv_padding_intervals,
+            compile_config=getattr(generation_config, "compile_config", None),
+        )

-        #
-        if
-
-
+        # We set the number of padding intervals for Q and KV
+        self.q_padding_intervals = num_q_padding_intervals if num_q_padding_intervals > 0 else NUM_Q_PADDING_INTERVALS
+        self.kv_padding_intervals = (
+            num_kv_padding_intervals if num_kv_padding_intervals > 0 else NUM_KV_PADDING_INTERVALS
+        )

+        # Log probability generation is not supported yet (TODO)
         if self.log_prob_generation:
             raise NotImplementedError("log_prob_generation is not supported yet")

+    def _decide_use_cuda_graphs(
+        self,
+        use_cuda_graph: bool | None,
+        num_q_padding_intervals: int,
+        num_kv_padding_intervals: int,
+        compile_config: CompileConfig | None,
+    ) -> bool:
+        """Returns whether or not to use cuda graphs for continuous batching, depending on the following criteria:
+        - (use_cuda_graph) which is the user choice
+        - (num_q_padding_intervals) or (num_kv_padding_intervals) which is used to pad inputs: if it was specified by
+          the user, it's probable they want to use cuda graphs so inputs need to be padded
+        - (compile_config): if compile is on, turn on cuda graphs unless the compile mode uses its own cudagraphs
+        If none of the above criteria are met, we use a default heuristic based on the attention implementation: we turn
+        on cuda graphs if and only if no attention mask is needed.
+        """
+        # If use_cuda_graph is specified, we follow the user's choice
+        if use_cuda_graph is not None:
+            return use_cuda_graph
+        # If a number of padding intervals was specified for either Q or KV, we activate cuda graphs
+        if num_q_padding_intervals > 0 or num_kv_padding_intervals > 0:
+            return True
+        # If a compile config was found, turn off cuda graphs if the compile config already uses them
+        if compile_config is not None:
+            options = torch._inductor.list_mode_options().get(compile_config.mode, compile_config.options)
+            compile_uses_cudagraphs = options.get("triton.cudagraphs", False)
+            if compile_uses_cudagraphs:
+                logger.warning(
+                    f"Compile config {compile_config.mode = } uses cudagraphs, which usually does not work well with "
+                    "continuous batching. We recommend using mode 'default' or 'max-autotune-no-cudagraphs' instead."
+                )
+            return not compile_uses_cudagraphs  # TODO: should this also match the dynamic shapes?
+        # Otherwise we have a default heuristic based on the attention implementation:
+        # attention implementations where an attention mask is needed suffer a lot more from the padding associated
+        # with cuda graphs, so default is to turn cuda graphs off for those implementations
+        use_cuda_graph = not attn_mask_is_needed(self.model.config)
+        logger.warning(
+            f"No behavior specified for use_cuda_graph, defaulting to {use_cuda_graph = } because "
+            f"{self.model.config._attn_implementation = }. If you want to save memory, turn off cuda graphs, but "
+            "they can improve performances."
+        )
+        return use_cuda_graph
+
     @traced
     def start(self) -> None:
         """Start the background generation thread."""
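`_decide_use_cuda_graphs` resolves the CUDA-graph policy in order of precedence: explicit user choice, user-supplied padding intervals, the compile configuration, then the attention-mask heuristic. The compile-config check can be reproduced on its own; the sketch below inspects which inductor modes already enable their own cudagraphs (the exact output depends on the installed PyTorch version):

```python
import torch

for mode in ("default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"):
    options = torch._inductor.list_mode_options().get(mode, {}) or {}
    uses_cudagraphs = options.get("triton.cudagraphs", False)
    print(f"{mode}: triton.cudagraphs={uses_cudagraphs}")
# On recent PyTorch builds, "reduce-overhead" and "max-autotune" typically enable
# triton.cudagraphs, which is why the manager then skips its own graph capture.
```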
@@ -822,7 +901,7 @@ class ContinuousBatchingManager:
             logger.warning("\nBatch processor was not initialized.")
         else:
             if self.batch_processor.cache.use_prefix_sharing:
-                logger.
+                logger.info(
                     f"\nPrefix sharing was on. Total prefix length: {self.batch_processor.cache._total_prefix_length}"
                 )

@@ -884,6 +963,7 @@ class ContinuousBatchingManager:
         state = RequestState(
             request_id=request_id,
             initial_tokens=list(input_ids),
+            num_children=self.num_return_sequences - 1,
             record_timestamps=record_timestamps,
             tokens_to_process=list(input_ids),
             max_new_tokens=max_new_tokens,
@@ -902,6 +982,10 @@ class ContinuousBatchingManager:
         streaming: bool = False,
         record_timestamps: bool = False,
     ) -> None:
+        # If there is prefix sharing, we sort the inputs to maximize cache hits
+        if self._use_prefix_sharing:
+            inputs = sorted(inputs, reverse=True)
+        # Add requests in order
         for input_ids in inputs:
             self.add_request(
                 input_ids, max_new_tokens=max_new_tokens, streaming=streaming, record_timestamps=record_timestamps
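When prefix sharing is active, `add_requests` now sorts the prompts (plain lists of token ids) in reverse lexicographic order, so prompts with a common prefix are scheduled back to back and the shared blocks can be reused immediately. A tiny demonstration of the effect, with made-up token ids:

```python
# Three prompts, two of which share the prefix [101, 7, 7]
inputs = [
    [101, 7, 7, 42],
    [101, 9, 1],
    [101, 7, 7, 99],
]
use_prefix_sharing = True
if use_prefix_sharing:
    inputs = sorted(inputs, reverse=True)
print(inputs)
# [[101, 9, 1], [101, 7, 7, 99], [101, 7, 7, 42]]
# The two prompts sharing [101, 7, 7] are now adjacent, maximizing cache hits.
```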
@@ -972,8 +1056,9 @@ class ContinuousBatchingManager:
             self.model.device,
             self.model.dtype,
             tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting
-
+            allow_block_sharing=self._allow_block_sharing,
         )
+        self._use_prefix_sharing = paged_attention_cache.use_prefix_sharing  # update the approximation
         logger.debug(f"PagedAttentionCache created in {perf_counter() - t0} seconds")

         scheduler = None
@@ -999,6 +1084,8 @@ class ContinuousBatchingManager:
             scheduler=scheduler(paged_attention_cache, self.manual_eviction),
             manual_eviction=self.manual_eviction,
             use_cuda_graph=self.use_cuda_graph,
+            q_padding_intervals=self.q_padding_intervals,
+            kv_padding_intervals=self.kv_padding_intervals,
         )
         self.batch_processor = batch_processor
         self.current_batch = 0
@@ -1024,13 +1111,12 @@ class ContinuousBatchingManager:
             # Debug logging of the current memory usage
             if logger.level <= logging.DEBUG:
                 device, total, reserved, allocated = get_device_and_memory_breakdown()
-
+                available_memory = total - max(allocated, reserved)
+                logger.debug(
+                    f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}, Available: {available_memory}"
+                )

             self._generation_step()
-
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
-            # Processor updates the batch after generation step is truly over
             batch_processor.update_batch()

     @traced
@@ -1072,7 +1158,7 @@ class ContinuousMixin:
         max_queue_size: int = 0,
         num_q_cuda_graphs: int = 0,
         num_kv_cuda_graphs: int = 0,
-
+        allow_block_sharing: bool = True,
         block: bool = True,
         timeout: float | None = None,
     ) -> Generator[ContinuousBatchingManager]:
@@ -1082,7 +1168,7 @@ class ContinuousMixin:
             max_queue_size,
             num_q_cuda_graphs,
             num_kv_cuda_graphs,
-
+            allow_block_sharing,
         )
         manager.start()
         try:
@@ -1099,18 +1185,19 @@ class ContinuousMixin:
         generation_config: GenerationConfig | None = None,
         manual_eviction: bool = False,
         max_queue_size: int = 0,
-
-
-
+        num_q_padding_intervals: int = 0,
+        num_kv_padding_intervals: int = 0,
+        allow_block_sharing: bool = True,
     ) -> ContinuousBatchingManager:
         """Initialize a manager for continuous batching inference.

         Args:
-            generation_config:
+            generation_config: An optional generation configuration, which may contain a CompileConfig object
             manual_eviction: Whether to manually evict requests from the cache
             max_queue_size: Maximum size of the input request queue
-
-
+            num_q_padding_intervals: Number of intervals used to pad the query dimension
+            num_kv_padding_intervals: Number of intervals used to pad the keys/values dimension
+            allow_block_sharing: A flag to allow block sharing if the model has some full attention layers

         Returns:
             `ContinuousBatchingManager`: The manager instance to add requests and retrieve results.
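Taken together with the earlier constructor changes, the new `init_continuous_batching` keyword arguments plumb the padding intervals and block sharing down to the manager. A hedged usage sketch follows; the checkpoint name is a placeholder, teardown is omitted, and only calls whose signatures appear in this diff (`init_continuous_batching`, `start`, `add_requests`, `get_result`) are exercised alongside the standard `Auto*` loading APIs:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tok = AutoTokenizer.from_pretrained("some-org/some-causal-lm")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("some-org/some-causal-lm")

manager = model.init_continuous_batching(
    generation_config=GenerationConfig(max_new_tokens=64, do_sample=False),
    num_q_padding_intervals=4,   # pad the packed query dimension to 4 bucket sizes
    num_kv_padding_intervals=8,  # pad the read-index (keys/values) dimension to 8 bucket sizes
    allow_block_sharing=True,    # share prefix blocks on full-attention layers when possible
)
manager.start()  # background generation thread
manager.add_requests(inputs=[tok("Hello there").input_ids], max_new_tokens=32)
result = manager.get_result(timeout=5)  # returns a finished request, if any, within the timeout
```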
@@ -1132,9 +1219,9 @@ class ContinuousMixin:
             generation_config=gen_config,
             manual_eviction=manual_eviction,
             max_queue_size=max_queue_size,
-
-
-
+            num_q_padding_intervals=num_q_padding_intervals,
+            num_kv_padding_intervals=num_kv_padding_intervals,
+            allow_block_sharing=allow_block_sharing,
         )

         # TODO: support streaming
@@ -1144,11 +1231,11 @@ class ContinuousMixin:
         self,
         inputs: list[list[int]],
         generation_config: GenerationConfig | None = None,
-
-
-
-        allow_prefix_sharing: bool = True,
+        num_q_padding_intervals: int = 0,
+        num_kv_padding_intervals: int = 0,
+        allow_block_sharing: bool = True,
         record_timestamps: bool = False,
+        progress_bar: bool = True,
         **kwargs,
     ) -> dict[str, GenerationOutput]:
         """Generate sequences for a batch of prompts using continuous batching.
@@ -1156,14 +1243,15 @@ class ContinuousMixin:
         Args:
             inputs: List of input token sequences (prompts)
             generation_config: Optional generation configuration
-
-
+            num_q_padding_intervals: Number of intervals used to pad the query dimension
+            num_kv_padding_intervals: Number of intervals used to pad the keys/values dimension
+            allow_block_sharing: A flag to allow block sharing if the model has some full attention layers
+            record_timestamps: If set to true, the requests will have a timestamp for each token generated
+            progress_bar: If set to true, a progress bar will be displayed
             **kwargs: Additional generation parameters

         Returns:
-            `
-            if not handled otherwise) for each input prompt, in the same order.
-            Returns an empty list `[]` for requests that failed.
+            `dict[str, GenerationOutput]`: a dictionary of request ids to GenerationOutput objects
         """
         if not inputs:
             return {}
@@ -1173,26 +1261,30 @@ class ContinuousMixin:

         # Initialize manager with the batch inputs
         results = {}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
+        gen_cfg = self.generation_config if generation_config is None else generation_config
+        num_requests = len(inputs) * gen_cfg.num_return_sequences
+        # Prepare context managers for the main loop
+        manager_cm = self.continuous_batching_context_manager(
+            generation_config=generation_config,
+            num_q_cuda_graphs=num_q_padding_intervals,
+            num_kv_cuda_graphs=num_kv_padding_intervals,
+            allow_block_sharing=allow_block_sharing,
+            block=True,
+            timeout=5,
+        )
+        logging_cm = logging_redirect_tqdm([logger])
+        pbar_cm = tqdm(
+            total=num_requests,
+            disable=(not progress_bar),
+            desc=f"Solving {num_requests} requests",
+            unit="request",
+        )
+        # Main loop
+        with manager_cm as manager, logging_cm, pbar_cm as pbar:
             try:
-                manager.add_requests(
+                manager.add_requests(
+                    inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"), record_timestamps=record_timestamps
+                )
                 finished_count = 0
                 while finished_count < num_requests:
                     result = manager.get_result(timeout=1)