transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their respective public registries, and is provided for informational purposes only.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/generation/configuration_utils.py

@@ -105,8 +105,9 @@ class GenerationConfig(PushToHubMixin):
     > Parameters that control the length of the output

     max_length (`int`, *optional*, defaults to 20):
-
-        `
+        `max_new_tokens` is recommended for controlling how many tokens the model generates.
+        `max_length` remains for backward compatibility.
+
     max_new_tokens (`int`, *optional*):
         The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
     min_length (`int`, *optional*, defaults to 0):
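The updated docstring steers callers toward `max_new_tokens`. A minimal usage sketch (any causal LM checkpoint works; gpt2 is used here only for brevity):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello, my name is", return_tensors="pt")
# max_new_tokens bounds only the freshly generated tokens; the legacy
# max_length counts the prompt too and is kept for backward compatibility.
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))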
transformers/generation/continuous_batching/__init__.py

@@ -15,12 +15,16 @@
 from .cache import PagedAttentionCache
 from .continuous_api import ContinuousBatchingManager, ContinuousMixin
 from .requests import RequestState, RequestStatus
+from .scheduler import FIFOScheduler, PrefillFirstScheduler, Scheduler


 __all__ = [
     "ContinuousBatchingManager",
     "ContinuousMixin",
+    "FIFOScheduler",
     "PagedAttentionCache",
+    "PrefillFirstScheduler",
     "RequestState",
     "RequestStatus",
+    "Scheduler",
 ]
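These re-exports make the scheduler classes part of the subpackage's public namespace, so the following import now works (a sketch of the user-visible effect):

from transformers.generation.continuous_batching import (
    FIFOScheduler,
    PrefillFirstScheduler,
    Scheduler,
)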
transformers/generation/continuous_batching/continuous_api.py

@@ -29,7 +29,7 @@ from tqdm import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm

 from ...configuration_utils import PretrainedConfig
-from ...generation.configuration_utils import GenerationConfig
+from ...generation.configuration_utils import CompileConfig, GenerationConfig
 from ...generation.logits_process import LogitsProcessor
 from ...utils.logging import logging
 from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced

@@ -45,17 +45,20 @@ generation goes on, there are two dimensions that change:
 - the number of keys/values tokens (KV), which grows as the cache does

 To solve this, we slice along those dimensions to fixed lengths. The size of the slices is controlled by the variables
-
-number of queries tokens is 1000, and
-1000 / 4 = 250 tokens, ie. to 250, 500, 750 or 1000 queries tokens.
+num_x_padding_intervals: NUM_X_PADDING_INTERVALS means that we create at most NUM_X_PADDING_INTERVALS graphs for the X
+dimension. So if the maximum number of queries tokens is 1000, and NUM_Q_PADDING_INTERVALS is 4, we will slice the
+number of queries token by intervals of 1000 / 4 = 250 tokens, ie. to 250, 500, 750 or 1000 queries tokens.

 Smaller slices means more granularity and thus less padding. But since each graph takes up space on the GPU and time to
 create, we don't want to many graphs. And since the size of the KV dimension is the number of queries tokens plus the
 number of tokens cached, dimension of KV is usually much larger than the dimension of Q. So we have more granularity
 for the KV dimension than the query dimension.
+
+This variable used to be called NUM_X_CUDA_GRAPHS, but we renamed it to NUM_X_PADDING_INTERVALS because it is used for
+padding in the case of cuda graphs AND torch.compile.
 """
-
-
+NUM_Q_PADDING_INTERVALS = 4
+NUM_KV_PADDING_INTERVALS = 8


 def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:

@@ -63,7 +66,7 @@ def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
     interval_size = max_value // nb_intervals
     if interval_size == 0:
         return max_value
-    padded = ceil(size / interval_size) * interval_size
+    padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
     return min(padded, max_value)

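A standalone copy of pad_by_intervals as it reads after this hunk, to make the size == 0 fix concrete (in rc0 a zero-sized input padded to 0; it now pads up to one full interval):

from math import ceil

def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
    # Snap `size` up to the next multiple of max_value // nb_intervals, capped
    # at max_value, so at most nb_intervals distinct shapes are ever produced.
    interval_size = max_value // nb_intervals
    if interval_size == 0:
        return max_value
    padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
    return min(padded, max_value)

# With a 1000-token budget and 4 intervals, sizes snap to 250/500/750/1000:
assert pad_by_intervals(1, 1000, 4) == 250
assert pad_by_intervals(251, 1000, 4) == 500
assert pad_by_intervals(0, 1000, 4) == 250  # was 0 before this change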
@@ -188,6 +191,8 @@ class ContinuousBatchProcessor:
         scheduler: Scheduler,
         manual_eviction: bool,
         use_cuda_graph: bool,
+        q_padding_intervals: int,
+        kv_padding_intervals: int,
     ) -> None:
         """Initialize the continuous batch processor.

@@ -221,7 +226,14 @@ class ContinuousBatchProcessor:
         # Accumulator for batch scheduling
         self.requests_in_batch: list[RequestState] = []
         # Cuda graphs for the generation step
+        self.q_padding_intervals = q_padding_intervals
+        self.kv_padding_intervals = kv_padding_intervals
         self._graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] | None = {} if use_cuda_graph else None
+        # Compile-related arguments
+        self.compile_config: CompileConfig | None = getattr(generation_config, "compile_config", None)
+        self._forward_process_and_sample_is_compiled = False
+
+        self._pad_inputs = use_cuda_graph or (self.compile_config is not None and not self.compile_config.dynamic)

         # Set up metrics collector
         self.max_batch_tokens = cache.max_batch_tokens

@@ -627,28 +639,39 @@ class ContinuousBatchProcessor:
     def _generation_step(self, model: nn.Module, logit_processor: LogitsProcessor, do_sample: bool) -> None:
         """Perform a single generation step."""

-        # If
+        # If a compile config is specified, we compile the forward pass once in a wrapper
+        if self.compile_config is not None and not self._forward_process_and_sample_is_compiled:
+            self._forward_process_and_sample = torch.compile(
+                self._forward_process_and_sample,
+                fullgraph=self.compile_config.fullgraph,
+                mode=self.compile_config.mode,
+                dynamic=self.compile_config.dynamic,
+                backend=self.compile_config.backend,
+                options=self.compile_config.options,
+            )
+            self._forward_process_and_sample_is_compiled = True
+
+        # If inputs are static sized, we find the padded sizes of the queries and keys/values
+        if self._pad_inputs:
+            padded_q = pad_by_intervals(self.actual_query_length, self.max_batch_tokens, self.q_padding_intervals)
+            max_read_index_size = max(self.actual_index_sizes[i][0] for i in range(self.cache.num_groups))
+            padded_read_index_size = pad_by_intervals(
+                max_read_index_size - self.max_batch_tokens,
+                self.cache.num_blocks * self.cache.block_size,
+                self.kv_padding_intervals,
+            )
+        else:
+            padded_q, padded_read_index_size = 0, 0
+        # Retrieve the model kwargs with or without padding
+        batch_data = self.get_model_kwargs(padded_q, padded_read_index_size)
+
+        # If we are not using cuda graphs, we perform the generation step and return
         if self._graphs is None:
-            batch_data = self.get_model_kwargs()
             self._forward_process_and_sample(model, batch_data, logit_processor, do_sample)
             return None

-        # Determine the padded size of the queries and keys/values
-        padded_q = pad_by_intervals(self.actual_query_length, self.max_batch_tokens, NUM_Q_CUDA_GRAPHS)
-
-        max_read_index_size = max(self.actual_index_sizes[i][0] for i in range(self.cache.num_groups))
-        padded_read_index_size = pad_by_intervals(
-            max_read_index_size - self.max_batch_tokens,
-            self.cache.num_blocks * self.cache.block_size,
-            NUM_KV_CUDA_GRAPHS,
-        )
-
-        # Get the batch data and the associated graph
-        batch_data = self.get_model_kwargs(padded_q, padded_read_index_size)
-
-        graph = self._graphs.get((padded_q, padded_read_index_size))
-
         # If we have a graph that fits, we replay it
+        graph = self._graphs.get((padded_q, padded_read_index_size))
         if graph is not None:
             graph.replay()
             return None
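The hunk above compiles _forward_process_and_sample lazily, on first use, from the fields of the CompileConfig. A minimal self-contained sketch of that compile-once-then-reuse pattern (the Runner class is illustrative, not part of transformers):

import torch

class Runner:
    def __init__(self, mode: str | None = None, dynamic: bool | None = None):
        self.mode, self.dynamic = mode, dynamic
        self._compiled = False

    def _step(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x) * 2

    def run(self, x: torch.Tensor) -> torch.Tensor:
        # Compile lazily so the slow first compilation happens at most once,
        # then keep calling the compiled wrapper on every later step.
        if not self._compiled:
            self._step = torch.compile(self._step, mode=self.mode, dynamic=self.dynamic)
            self._compiled = True
        return self._step(x)

print(Runner().run(torch.randn(4)))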
@@ -673,7 +696,6 @@
     ) -> None:
         """This function performs the forward pass, logits processing, and sampling; which are broken down into smaller
         function to be easier to trace with OpenTelemetry."""
-        # with torch.no_grad():
         logits = self._model_forward(model, batch_data)
         # if self.log_prob_generation: batch_processor.output_probs.copy_(logits) # TODO
         probs = self._process_logit(batch_data, logits, logit_processor)

@@ -691,6 +713,7 @@
         # Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size]
         # but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size]
         batch_size, seq_len, vocab_size = logits.shape
+        # NOTE: to be an exact match with generate, we should also convert logits2d to float32 here, but it's not needed in practice
         logits_2d = logits.view(batch_size * seq_len, vocab_size)
         input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)
         # Process with 2D tensors
@@ -727,8 +750,8 @@ class ContinuousBatchingManager:
         generation_config: GenerationConfig,
         manual_eviction: bool = False,
         max_queue_size: int = 0,
-
-
+        num_q_padding_intervals: int = 0,
+        num_kv_padding_intervals: int = 0,
         allow_prefix_sharing: bool = True,
     ) -> None:
         """Initialize the continuous batching manager.

@@ -737,19 +760,13 @@
             model: The language model for generation
             generation_config: Configuration for generation parameters
             max_queue_size: Maximum size of the request queue (0 = unlimited)
-
-
+            num_q_padding_intervals: (optional) Number of intervals used to pad the query dimension
+            num_kv_padding_intervals: (optional) Number of intervals used to pad the keys/values dimension
             allow_prefix_sharing: (optional) Whether to allow prefix sharing if the model has only full attention layers
         """
+        # Reloade paged version if necessary
         if "paged|" not in model.config._attn_implementation:
-
-            model.config._attn_implementation = attn_implementation
-
-            # lazy loading flash attention including kernel variations
-            if "flash" in attn_implementation:
-                from ...modeling_flash_attention_utils import lazy_import_paged_flash_attention
-
-                lazy_import_paged_flash_attention(attn_implementation)
+            model.set_attn_implementation(f"paged|{model.config._attn_implementation}")

         self.model = model.eval()
         generation_config = model.generation_config if generation_config is None else generation_config
@@ -764,38 +781,69 @@
         self.model.generation_config.top_p = None
         self.do_sample = getattr(generation_config, "do_sample", True)
         self.logit_processor = self.model._get_logits_processor(generation_config)
-        use_cuda_graph: bool | None = getattr(generation_config, "use_cuda_graph", None)
         self.profile = getattr(generation_config, "profile", False)  # TODO: not supported yet
         self.manual_eviction = manual_eviction
         self.batch_processor: ContinuousBatchProcessor | None = None
-
         self._allow_prefix_sharing = allow_prefix_sharing

-
-
-
-
-
-
-        # If the use of cuda graphs is not specified, we follow the user's choice, otherwise we have a default heuristic
-        else:
-            # Attention implementations where an attention mask is needed suffer a lot more from the padding associated
-            # with cuda graphs, so default is to turn cuda graphs off for those implementations
-            self.use_cuda_graph = not attn_mask_is_needed(self.model.config)
-            logger.warning(
-                f"No behavior specified for use_cuda_graph, defaulting to {self.use_cuda_graph = } because "
-                f"{self.model.config._attn_implementation = }. If you want to save memory, turn off cuda graphs, but "
-                "they can improve performances."
-            )
+        self.use_cuda_graph = self._decide_use_cuda_graphs(
+            use_cuda_graph=getattr(generation_config, "use_cuda_graph", None),
+            num_q_padding_intervals=num_q_padding_intervals,
+            num_kv_padding_intervals=num_kv_padding_intervals,
+            compile_config=getattr(generation_config, "compile_config", None),
+        )

-        #
-        if
-
-
+        # We set the number of padding intervals for Q and KV
+        self.q_padding_intervals = num_q_padding_intervals if num_q_padding_intervals > 0 else NUM_Q_PADDING_INTERVALS
+        self.kv_padding_intervals = (
+            num_kv_padding_intervals if num_kv_padding_intervals > 0 else NUM_KV_PADDING_INTERVALS
+        )

         if self.log_prob_generation:
             raise NotImplementedError("log_prob_generation is not supported yet")

+    def _decide_use_cuda_graphs(
+        self,
+        use_cuda_graph: bool | None,
+        num_q_padding_intervals: int,
+        num_kv_padding_intervals: int,
+        compile_config: CompileConfig | None,
+    ) -> bool:
+        """Returns whether or not to use cuda graphs for continuous batching, depending on the following criteria:
+        - (use_cuda_graph) which is the user choice
+        - (num_q_padding_intervals) or (num_kv_padding_intervals) which is used to pad inputs: if it was specified by
+          the user, it's probable they want to use cuda graphs so inputs need to be padded
+        - (compile_config): if compile is on, turn on cuda graphs unless the compile mode uses its own cudagraphs
+        If none of the above criteria are met, we use a default heuristic based on the attention implementation: we turn
+        on cuda graphs if and only if no attention mask is needed.
+        """
+        # If use_cuda_graph is specified, we follow the user's choice
+        if use_cuda_graph is not None:
+            return use_cuda_graph
+        # If a number of padding intervals was specified for either Q or KV, we activate cuda graphs
+        if num_q_padding_intervals > 0 or num_kv_padding_intervals > 0:
+            return True
+        # If a compile config was found, turn off cuda graphs if the compile config already uses them
+        if compile_config is not None:
+            options = torch._inductor.list_mode_options().get(compile_config.mode, compile_config.options)
+            compile_uses_cudagraphs = options.get("triton.cudagraphs", False)
+            if compile_uses_cudagraphs:
+                logger.warning(
+                    f"Compile config {compile_config.mode = } uses cudagraphs, which usually does not work well with "
+                    "continuous batching. We recommend using mode 'default' or 'max-autotune-no-cudagraphs' instead."
+                )
+            return not compile_uses_cudagraphs  # TODO: should this also match the dynamic shapes?
+        # Otherwise we have a default heuristic based on the attention implementation:
+        # attention implementations where an attention mask is needed suffer a lot more from the padding associated
+        # with cuda graphs, so default is to turn cuda graphs off for those implementations
+        use_cuda_graph = not attn_mask_is_needed(self.model.config)
+        logger.warning(
+            f"No behavior specified for use_cuda_graph, defaulting to {use_cuda_graph = } because "
+            f"{self.model.config._attn_implementation = }. If you want to save memory, turn off cuda graphs, but "
+            "they can improve performances."
+        )
+        return use_cuda_graph

     @traced
     def start(self) -> None:
         """Start the background generation thread."""
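The compile_config branch above leans on inductor's mode table; torch._inductor.list_mode_options is the same (private) helper the new code calls. A short sketch of that lookup:

import torch

# Modes such as "reduce-overhead" and "max-autotune" bring their own CUDA
# graphs (triton.cudagraphs=True), so manual graph capture is skipped for them.
for mode in ("default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"):
    options = torch._inductor.list_mode_options().get(mode, {})
    print(mode, "->", options.get("triton.cudagraphs", False))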
@@ -822,7 +870,7 @@
             logger.warning("\nBatch processor was not initialized.")
         else:
             if self.batch_processor.cache.use_prefix_sharing:
-                logger.
+                logger.info(
                     f"\nPrefix sharing was on. Total prefix length: {self.batch_processor.cache._total_prefix_length}"
                 )

@@ -999,6 +1047,8 @@
             scheduler=scheduler(paged_attention_cache, self.manual_eviction),
             manual_eviction=self.manual_eviction,
             use_cuda_graph=self.use_cuda_graph,
+            q_padding_intervals=self.q_padding_intervals,
+            kv_padding_intervals=self.kv_padding_intervals,
         )
         self.batch_processor = batch_processor
         self.current_batch = 0

@@ -1024,12 +1074,15 @@
         # Debug logging of the current memory usage
         if logger.level <= logging.DEBUG:
             device, total, reserved, allocated = get_device_and_memory_breakdown()
-
+            available_memory = total - max(allocated, reserved)
+            logger.debug(
+                f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}, Available: {available_memory}"
+            )

         self._generation_step()

         if torch.cuda.is_available():
-            torch.cuda.synchronize()
+            torch.cuda.synchronize()  # FIXME: why is this needed?
         # Processor updates the batch after generation step is truly over
         batch_processor.update_batch()

@@ -1099,18 +1152,19 @@ class ContinuousMixin:
         generation_config: GenerationConfig | None = None,
         manual_eviction: bool = False,
         max_queue_size: int = 0,
-
-
+        num_q_padding_intervals: int = 0,
+        num_kv_padding_intervals: int = 0,
         allow_prefix_sharing: bool = True,
     ) -> ContinuousBatchingManager:
         """Initialize a manager for continuous batching inference.

         Args:
-            generation_config:
+            generation_config: An optional generation configuration, which may contain a CompileConfig object
             manual_eviction: Whether to manually evict requests from the cache
             max_queue_size: Maximum size of the input request queue
-
-
+            num_q_padding_intervals: Number of intervals used to pad the query dimension
+            num_kv_padding_intervals: Number of intervals used to pad the keys/values dimension
+            allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers

         Returns:
             `ContinuousBatchingManager`: The manager instance to add requests and retrieve results.

@@ -1132,8 +1186,8 @@
             generation_config=gen_config,
             manual_eviction=manual_eviction,
             max_queue_size=max_queue_size,
-
-
+            num_q_padding_intervals=num_q_padding_intervals,
+            num_kv_padding_intervals=num_kv_padding_intervals,
             allow_prefix_sharing=allow_prefix_sharing,
         )

@@ -1144,11 +1198,11 @@
         self,
         inputs: list[list[int]],
         generation_config: GenerationConfig | None = None,
-
-
-        num_kv_cuda_graphs: int = 0,
+        num_q_padding_intervals: int = 0,
+        num_kv_padding_intervals: int = 0,
         allow_prefix_sharing: bool = True,
         record_timestamps: bool = False,
+        progress_bar: bool = True,
         **kwargs,
     ) -> dict[str, GenerationOutput]:
         """Generate sequences for a batch of prompts using continuous batching.

@@ -1156,14 +1210,15 @@
         Args:
             inputs: List of input token sequences (prompts)
             generation_config: Optional generation configuration
-
-
+            num_q_padding_intervals: Number of intervals used to pad the query dimension
+            num_kv_padding_intervals: Number of intervals used to pad the keys/values dimension
+            allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers
+            record_timestamps: If set to true, the requests will have a timestamp for each token generated
+            progress_bar: If set to true, a progress bar will be displayed
             **kwargs: Additional generation parameters

         Returns:
-            `
-            if not handled otherwise) for each input prompt, in the same order.
-            Returns an empty list `[]` for requests that failed.
+            `dict[str, GenerationOutput]`: a dictionary of request ids to GenerationOutput objects
         """
         if not inputs:
             return {}

@@ -1177,8 +1232,8 @@
         with (
             self.continuous_batching_context_manager(
                 generation_config=generation_config,
-                num_q_cuda_graphs=
-                num_kv_cuda_graphs=
+                num_q_cuda_graphs=num_q_padding_intervals,
+                num_kv_cuda_graphs=num_kv_padding_intervals,
                 allow_prefix_sharing=allow_prefix_sharing,
                 block=True,
                 timeout=5,
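A hedged usage sketch of the renamed kwargs, assuming the method cropped out of this view is ContinuousMixin.generate_batch and a CUDA-capable setup (the checkpoint name is only an example):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="cuda")
prompts = [tokenizer.encode(p) for p in ["Hello there!", "What is continuous batching?"]]
outputs = model.generate_batch(
    inputs=prompts,
    num_q_padding_intervals=4,   # rc0 called this num_q_cuda_graphs
    num_kv_padding_intervals=8,  # rc0 called this num_kv_cuda_graphs
    progress_bar=True,           # new in rc1
)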
transformers/image_processing_base.py

@@ -18,7 +18,7 @@ import os
 from typing import Any, Optional, TypeVar, Union

 import numpy as np
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, is_offline_mode

 from .dynamic_module_utils import custom_object_save
 from .feature_extraction_utils import BatchFeature as BaseBatchFeature

@@ -28,7 +28,6 @@ from .utils import (
     PROCESSOR_NAME,
     PushToHubMixin,
     copy_func,
-    is_offline_mode,
     logging,
     safe_load_json_file,
 )
transformers/integrations/__init__.py

@@ -19,7 +19,6 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_availa
 _import_structure = {
     "aqlm": ["replace_with_aqlm_linear"],
     "awq": [
-        "fuse_awq_modules",
         "post_init_awq_exllama_modules",
         "post_init_awq_ipex_modules",
         "replace_quantization_scales",

@@ -54,6 +53,7 @@ _import_structure = {
     "finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
     "fsdp": ["is_fsdp_enabled", "is_fsdp_managed_module"],
     "ggml": [
+        "GGUF_CONFIG_DEFAULTS_MAPPING",
         "GGUF_CONFIG_MAPPING",
         "GGUF_TOKENIZER_MAPPING",
         "_gguf_parse_value",

@@ -73,6 +73,7 @@ _import_structure = {
         "replace_kernel_forward_from_hub",
         "use_kernel_forward_from_hub",
         "use_kernel_func_from_hub",
+        "use_kernelized_func",
     ],
     "integration_utils": [
         "INTEGRATION_TO_CALLBACK",

@@ -165,7 +166,6 @@ else:
 if TYPE_CHECKING:
     from .aqlm import replace_with_aqlm_linear
     from .awq import (
-        fuse_awq_modules,
         post_init_awq_exllama_modules,
         post_init_awq_ipex_modules,
         replace_quantization_scales,

@@ -200,6 +200,7 @@ if TYPE_CHECKING:
     from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
     from .fsdp import is_fsdp_enabled, is_fsdp_managed_module
     from .ggml import (
+        GGUF_CONFIG_DEFAULTS_MAPPING,
         GGUF_CONFIG_MAPPING,
         GGUF_TOKENIZER_MAPPING,
         _gguf_parse_value,

@@ -214,6 +215,7 @@ if TYPE_CHECKING:
         replace_kernel_forward_from_hub,
         use_kernel_forward_from_hub,
         use_kernel_func_from_hub,
+        use_kernelized_func,
     )
     from .integration_utils import (
         INTEGRATION_TO_CALLBACK,
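Each new symbol lands twice above because the package resolves attributes lazily at runtime while the TYPE_CHECKING branch feeds static analyzers; once both lists agree, the lazy and the direct import paths should yield the same object (a sketch of the intended effect):

from transformers.integrations import GGUF_CONFIG_DEFAULTS_MAPPING
from transformers.integrations.ggml import GGUF_CONFIG_DEFAULTS_MAPPING as direct

assert GGUF_CONFIG_DEFAULTS_MAPPING is direct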
@@ -392,6 +392,15 @@ def _get_device_map(
         )
     else:
         inferred_max_memory = get_max_memory(max_memory)
+
+        # If the user does not provide `max_memory`, accelerate sets the WHOLE available cpu memory as available.
+        # This is unwanted, as we don't want to put an extremely tight bound and pressure on cpu if we are
+        # memory-constrained, especially if the model uses a WeightConverter (because there will be some
+        # uncontrollable cpu memory spikes during the conversions before we resave the weights). In those
+        # in-between cases, it's better to offload a bit more to disk, as otherwise we blow up cpu memory
+        if max_memory is None:
+            inferred_max_memory["cpu"] *= 0.90
+
     if hf_quantizer is not None:
         inferred_max_memory = hf_quantizer.adjust_max_memory(inferred_max_memory)
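The effect of the new branch: when the user passes no `max_memory`, only 90% of the free cpu RAM is treated as usable, so borderline models spill to disk instead of exhausting the host during weight conversion. A minimal sketch of just this rule using accelerate's `get_max_memory` (the surrounding function does more than this):

```python
from accelerate.utils import get_max_memory

max_memory = None  # the user did not constrain memory themselves
inferred_max_memory = get_max_memory(max_memory)  # e.g. {0: <gpu bytes>, "cpu": <bytes>}

if max_memory is None:
    # Keep ~10% cpu headroom for conversion-time spikes; anything that no
    # longer fits on cpu gets offloaded to disk instead.
    inferred_max_memory["cpu"] *= 0.90
```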
@@ -466,10 +475,10 @@ def expand_device_map(device_map, param_names):


 def accelerate_disk_offload(
+    model: "PreTrainedModel",
     disk_offload_folder: str | None,
     checkpoint_files: list[str] | None,
     device_map: dict,
-    expected_keys: list[str],
     sharded_metadata: dict | None,
     dtype: torch.dtype | None,
     weight_mapping=None,
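`accelerate_disk_offload` now receives the model itself and derives the expected parameter names from it (see the next hunk) instead of taking a precomputed `expected_keys` list. For reference, `expand_device_map` turns a module-level device map into a per-parameter one; the body below is an illustrative re-implementation under that assumption, not the library's code:

```python
def expand_device_map(device_map: dict, param_names) -> dict:
    """Assign each parameter the device of its longest matching module prefix."""
    expanded = {}
    for param in param_names:
        candidates = [k for k in device_map if k == "" or param == k or param.startswith(k + ".")]
        if candidates:
            expanded[param] = device_map[max(candidates, key=len)]
    return expanded

device_map = {"model.embed_tokens": 0, "model.layers": "disk", "lm_head": "cpu"}
print(expand_device_map(device_map, ["model.layers.0.self_attn.q_proj.weight"]))
# -> {'model.layers.0.self_attn.q_proj.weight': 'disk'}
```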
@@ -493,7 +502,8 @@ def accelerate_disk_offload(
     # In this case, the offload index is simply the existing safetensors (except if using custom weight loading
     # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
     if is_offloaded_safetensors:
-        param_device_map = expand_device_map(device_map, expected_keys)
+        meta_state_dict = model.state_dict()
+        param_device_map = expand_device_map(device_map, meta_state_dict.keys())
         str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
         if sharded_metadata is None:
             weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0])
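Calling `state_dict()` here is cheap because at this point the model's parameters still live on the meta device: the keys are real, but no tensor storage has been allocated. A small sketch of the same trick:

```python
import torch
from torch import nn

# Instantiate on the meta device: full module tree, zero memory for weights.
with torch.device("meta"):
    model = nn.Linear(4, 4)

meta_state_dict = model.state_dict()
print(list(meta_state_dict.keys()))      # ['weight', 'bias']
print(meta_state_dict["weight"].device)  # meta
```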
@@ -502,7 +512,9 @@ def accelerate_disk_offload(
             weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()}

         # Update the weight names according to the `weight_mapping`
-        weight_renaming_map = {
+        weight_renaming_map = {
+            rename_source_key(k, renamings, [], model.base_model_prefix, meta_state_dict)[0]: k for k in weight_map
+        }

         # Prepare the index using existing safetensors files
         disk_offload_index = {
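The direction of this dict matters: `rename_source_key` maps a checkpoint key to the name the model expects, and the comprehension inverts that so each expected name can be resolved back to its on-disk tensor. A toy version with a made-up renaming rule:

```python
# `gamma` -> `weight` is a stand-in rename; the real mapping comes from the
# model's weight_mapping / WeightConverter machinery.
def toy_rename(checkpoint_key: str) -> str:
    return checkpoint_key.replace(".gamma", ".weight")

weight_map = {"encoder.layer.0.norm.gamma": "model-00001-of-00002.safetensors"}
weight_renaming_map = {toy_rename(k): k for k in weight_map}
print(weight_renaming_map)
# -> {'encoder.layer.0.norm.weight': 'encoder.layer.0.norm.gamma'}
```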
@@ -13,88 +13,60 @@
 # limitations under the License.
 "AQLM (Additive Quantization of Language Model) integration file"

-from ..utils import ACCELERATE_MIN_VERSION, is_accelerate_available, is_aqlm_available, is_torch_available
+from ..quantizers.quantizers_utils import should_convert_module
+from ..utils import is_accelerate_available, is_torch_available, logging


+if is_accelerate_available():
+    from accelerate import init_empty_weights
+
 if is_torch_available():
     import torch.nn as nn

+logger = logging.get_logger(__name__)

-
-def replace_with_aqlm_linear(
-    model,
-    quantization_config=None,
-    linear_weights_not_to_quantize=None,
-    current_key_name=None,
-    has_been_replaced=False,
-):
+
+def replace_with_aqlm_linear(model, modules_to_not_convert: list[str] | None = None, quantization_config=None):
     """
     Public method that recursively replaces the Linear layers of the given model with AQLM quantized layers.
-    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
-    conversion has been successful or not.

     Args:
         model (`torch.nn.Module`):
             The model to convert, can be any `torch.nn.Module` instance.
-        quantization_config (`AqlmConfig`):
-            The quantization config object that contains the quantization parameters.
-        linear_weights_not_to_quantize (`list[str]`, *optional*):
+        modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
             A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
             converted.
-        current_key_name (`list`, *optional*):
-            A list that contains the current key name. This is used for recursion and should not be passed by the user.
-        has_been_replaced (`bool`, *optional*):
-            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
-            should not be passed by the user.
+        quantization_config (`AqlmConfig`):
+            The quantization config object that contains the quantization parameters.
     """
-    if not is_aqlm_available():
-        raise ValueError("AQLM is not available. Please install it with `pip install aqlm[cpu,gpu]`")
-
-    if not is_accelerate_available():
-        raise ValueError(
-            f"AQLM requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
-        )
-
-    if linear_weights_not_to_quantize is None:
-        linear_weights_not_to_quantize = []
-
-    from accelerate import init_empty_weights
     from aqlm import QuantizedLinear

-    for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
-        current_key_name.append(name)
-
-        if isinstance(module, nn.Linear):
-            # Check if the current key is not in the `linear_weights_not_to_quantize`
-            if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
-                with init_empty_weights():
-                    model._modules[name] = QuantizedLinear(
-                        module.in_features,
+    has_been_replaced = False
+    # we need this to correctly materialize the weights during quantization
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+        with init_empty_weights():
+            if isinstance(module, nn.Linear):
+                new_module = QuantizedLinear(
+                    module.in_features,
+                    module.out_features,
+                    bias=module.bias is not None,
+                    in_group_size=quantization_config.in_group_size,
+                    out_group_size=quantization_config.out_group_size,
+                    num_codebooks=quantization_config.num_codebooks,
+                    nbits_per_codebook=quantization_config.nbits_per_codebook,
+                )
+                new_module.source_cls = type(module)
+                new_module.requires_grad_(False)
+                model.set_submodule(module_name, new_module)
+                has_been_replaced = True

-                        module.out_features,
-                        bias=module.bias is not None,
-                        in_group_size=quantization_config.in_group_size,
-                        out_group_size=quantization_config.out_group_size,
-                        num_codebooks=quantization_config.num_codebooks,
-                        nbits_per_codebook=quantization_config.nbits_per_codebook,
-                    )
-                    has_been_replaced = True
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model using eetq but no linear modules were found in your model."
+            " Please double check your model architecture, or submit an issue on github if you think this is"
+            " a bug."
+        )

-                    # Store the module class in case we need to transpose the weight later
-                    model._modules[name].source_cls = type(module)
-                    # Force requires grad to False to avoid unexpected errors
-                    model._modules[name].requires_grad_(False)
-        if len(list(module.children())) > 0:
-            _, has_been_replaced = replace_with_aqlm_linear(
-                module,
-                quantization_config=quantization_config,
-                linear_weights_not_to_quantize=linear_weights_not_to_quantize,
-                current_key_name=current_key_name,
-                has_been_replaced=has_been_replaced,
-            )
-        # Remove the last key for recursion
-        current_key_name.pop(-1)
-    return model, has_been_replaced
+    return model