transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/quantizers/base.py
CHANGED

@@ -12,17 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from copy import deepcopy
 from typing import TYPE_CHECKING, Any

-from ..utils import
+from ..utils import is_torch_available, logging
 from ..utils.quantization_config import QuantizationConfigMixin, QuantizationMethod
 from .quantizers_utils import get_module_from_name


-if is_accelerate_available():
-    from accelerate.utils import find_tied_parameters
-
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel

@@ -45,50 +41,31 @@ def _assign_original_dtype(module, original_dtype):
         _assign_original_dtype(child, original_dtype)


-def get_keys_to_not_convert(model):
+def get_keys_to_not_convert(model) -> list:
     r"""
-
-    we may want to keep the lm_head in full precision for numerical stability reasons.
-    to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
-    int8.
-
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model
+    Function to automatically detect keys to not convert for usage like quantization. For example for CausalLM modules
+    we may want to keep the lm_head in full precision for numerical stability reasons.
     """
-    #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    intersection = set(list_last_module) - set(tied_keys)
-    list_untouched = list(set(tied_keys)) + list(intersection)
-
-    # remove ".weight" from the keys
-    names_to_remove = [".weight", ".bias"]
-    filtered_module_names = []
-    for name in list_untouched:
-        for name_to_remove in names_to_remove:
-            if name_to_remove in name:
-                name = name.replace(name_to_remove, "")
-        filtered_module_names.append(name)
-
-    return filtered_module_names
+    # remove tied weights
+    tied_keys = set()
+    if len(model.all_tied_weights_keys) > 0:
+        tied_keys = set(model.all_tied_weights_keys.values()) | set(model.all_tied_weights_keys.keys())
+
+    # remove last module
+    last_module_key = {list(model.named_parameters())[-1][0]}
+
+    # remove output emb
+    output_emb_module = model.get_output_embeddings()
+    output_emb_keys = {
+        name
+        for name, module in model.named_modules()
+        if output_emb_module is not None and id(module) == id(output_emb_module)
+    }
+    modules_to_not_convert = tied_keys | last_module_key | output_emb_keys
+
+    modules_to_not_convert = list({k.removesuffix(".weight") for k in modules_to_not_convert})
+
+    return list(modules_to_not_convert)


 class HfQuantizer(ABC):
@@ -100,26 +77,14 @@ class HfQuantizer(ABC):
     Attributes
         quantization_config (`transformers.utils.quantization_config.QuantizationConfigMixin`):
             The quantization config that defines the quantization parameters of your model that you want to quantize.
-        modules_to_not_convert (`list[str]`, *optional*):
-            The list of module names to not convert when quantizing the model.
-        required_packages (`list[str]`, *optional*):
-            The list of required pip packages to install prior to using the quantizer
         requires_calibration (`bool`):
            Whether the quantization method requires to calibrate the model before using it.
-        requires_parameters_quantization (`bool`):
-            Whether the quantization method requires to create a new Parameter. For example, for bitsandbytes, it is
-            required to create a new xxxParameter in order to properly quantize the model.
     """

     requires_calibration = False
-    required_packages = None
-    requires_parameters_quantization = False

     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         self.quantization_config = quantization_config
-
-        # -- Handle extra kwargs below --
-        self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
         self.pre_quantized = kwargs.pop("pre_quantized", True)

         if not self.pre_quantized and self.requires_calibration:
@@ -182,53 +147,16 @@ class HfQuantizer(ABC):
             return mapping[custom_dtype]
         return param.element_size()

-    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
-        """
-        Override this method if you want to adjust the `missing_keys`.
-
-        Args:
-            missing_keys (`list[str]`, *optional*):
-                The list of missing keys in the checkpoint compared to the state dict of the model
-        """
-        return missing_keys
-
-    def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: list[str]) -> list[str]:
-        """
-        Override this method if you want to adjust the `update_expected_keys`.
-
-        Args:
-            expected_keys (`list[str]`, *optional*):
-                The list of the expected keys in the initialized model.
-            loaded_keys (`list[str]`, *optional*):
-                The list of the loaded keys in the checkpoint.
-        """
-        return expected_keys
-
-    def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
-        return unexpected_keys
-
     def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
         """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
         return max_memory

     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         """
-        Check whether a given param needs
+        Check whether a given param needs to be quantized.
         """
         return False

-    def create_quantized_param(self, *args, **kwargs):
-        """
-        Take needed components from state_dict (those from which `param_needs_quantization` is True) and create
-        quantized param.
-        It usually also load the new param directly in the `model`.
-        Note: only applicable if requires_parameters_quantization == True.
-        """
-        if not self.requires_parameters_quantization:
-            raise AttributeError(
-                f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
-            )
-
     def validate_environment(self, *args, **kwargs):
         """
         This method is used to potentially check for potential conflicts with arguments that are
@@ -288,6 +216,11 @@ class HfQuantizer(ABC):
             kwargs (`dict`, *optional*):
                 The keyword arguments that are passed along `_process_model_after_weight_loading`.
         """
+        model.config.quantization_config = self.quantization_config
+
+        if self.pre_quantized and getattr(self.quantization_config, "dequantize", False):
+            self.remove_quantization_config(model)
+
         return self._process_model_after_weight_loading(model, **kwargs)

     def remove_quantization_config(self, model):
@@ -310,13 +243,7 @@ class HfQuantizer(ABC):
         Note not all quantization schemes support this.
         """
         model = self._dequantize(model)
-
-        # Delete quantizer and quantization config
-        del model.hf_quantizer
-        del model.config.quantization_config
-        del model.config._pre_quantization_dtype
-        del model.quantization_method
-        model.is_quantized = False
+        self.remove_quantization_config(model)

         return model

@@ -360,6 +287,8 @@ class HfQuantizer(ABC):
         if keep_in_fp32_modules is not None:
             modules_to_not_convert.extend(keep_in_fp32_modules)

+        modules_to_not_convert = list(set(modules_to_not_convert))
+
         return modules_to_not_convert

     @property
@@ -372,16 +301,12 @@ class HfQuantizer(ABC):
         """Flag indicating whether the quantized model can be compiled"""
         return False

-    def get_state_dict_and_metadata(self, model
+    def get_state_dict_and_metadata(self, model):
         """Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
         return None, {}

-    def update_state_dict_with_metadata(self, state_dict, metadata):
-        """Update state dict with metadata. Default behaviour returns state_dict"""
-        return state_dict
-
     @abstractmethod
-    def is_serializable(self
+    def is_serializable(self): ...

     @property
     @abstractmethod
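The most substantive change above is the rewrite of get_keys_to_not_convert: instead of reconstructing tied weights through accelerate's find_tied_parameters, it now derives the skip list from the model's own all_tied_weights_keys mapping, its last named parameter, and its output-embedding module. A hedged sketch of how the rewritten helper behaves, assuming a 5.0.0rc1 install (the checkpoint id is illustrative):

# Hedged sketch: exercising the rewritten get_keys_to_not_convert() shown above.
from transformers import AutoModelForCausalLM
from transformers.quantizers.base import get_keys_to_not_convert

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
# gpt2 ties lm_head.weight to transformer.wte.weight, so the tied pair, the
# last parameter, and the output embedding all land in the skip list, with any
# trailing ".weight" stripped by removesuffix().
print(get_keys_to_not_convert(model))  # e.g. ['lm_head', 'transformer.wte'] (built from sets, order varies)
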
transformers/quantizers/quantizer_aqlm.py
CHANGED

@@ -39,12 +39,9 @@ class AqlmHfQuantizer(HfQuantizer):
     """

     requires_calibration = True
-    required_packages = ["aqlm"]
-    optimum_quantizer = None

     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.quantization_config = quantization_config

     def validate_environment(self, *args, **kwargs):
         if not is_accelerate_available():
@@ -77,7 +74,6 @@ class AqlmHfQuantizer(HfQuantizer):
             quantization_config=self.quantization_config,
             linear_weights_not_to_quantize=self.quantization_config.linear_weights_not_to_quantize,
         )
-        model.config.quantization_config = self.quantization_config

     @property
     def is_trainable(self) -> bool:
@@ -90,5 +86,5 @@ class AqlmHfQuantizer(HfQuantizer):
             )
             return False

-    def is_serializable(self
+    def is_serializable(self):
         return True
transformers/quantizers/quantizer_auto_round.py
CHANGED

@@ -36,7 +36,6 @@ class AutoRoundQuantizer(HfQuantizer):

     # AutoRound requires data calibration - we support only inference
     requires_calibration = True
-    required_packages = ["auto_round"]

     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)
@@ -76,6 +75,6 @@ class AutoRoundQuantizer(HfQuantizer):
     def is_trainable(self) -> bool:
         return False

-    def is_serializable(self
+    def is_serializable(self):
         ## for gptq/awq models, the quantization config will be changed
         return True
transformers/quantizers/quantizer_awq.py
CHANGED

@@ -22,8 +22,8 @@ from .base import HfQuantizer
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel

-from ..utils import is_accelerate_available,
-from ..utils.quantization_config import
+from ..utils import is_accelerate_available, is_gptqmodel_available, is_torch_available, logging
+from ..utils.quantization_config import AwqBackend


 if is_torch_available():
@@ -40,60 +40,18 @@ class AwqQuantizer(HfQuantizer):
     # AWQ requires data calibration - we support only inference
     requires_calibration = True

-    required_packages = ["awq", "accelerate"]
-
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)

-    def validate_environment(self,
-        if not
-            raise ImportError(
+    def validate_environment(self, **kwargs):
+        if not is_gptqmodel_available():
+            raise ImportError(
+                "Loading an AWQ quantized model requires gptqmodel. Please install it with `pip install gptqmodel`"
+            )

         if not is_accelerate_available():
             raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)")

-        if (
-            self.quantization_config.version == AWQLinearVersion.GEMM
-            and not torch.cuda.is_available()
-            and not torch.xpu.is_available()
-        ):
-            logger.warning_once("No CUDA or XPU found, consider switching to the IPEX version for CPU-only execution.")
-            self.quantization_config.version = AWQLinearVersion.IPEX
-
-        if self.quantization_config.version == AWQLinearVersion.IPEX:
-            if version.parse(importlib.metadata.version("autoawq")) < version.parse("0.2.6"):
-                raise RuntimeError(
-                    "To use IPEX backend, you need autoawq>0.2.6. Please install the latest version or from source."
-                )
-            if device_map is None:
-                logger.warning_once(
-                    "You have loaded an AWQ model without setting device_map, please set 'cpu' or 'xpu' or 'auto'"
-                )
-            elif isinstance(device_map, dict) and "disk" in device_map.values():
-                raise ValueError(
-                    "You are attempting to load an IPEX version AWQ model with a device_map that contains disk device."
-                    " This is not supported. Please make sure only cpu and xpu in the device_map."
-                )
-        else:
-            if not torch.cuda.is_available() and not torch.xpu.is_available():
-                raise RuntimeError(
-                    "GPU is required to run AWQ quantized model. You can use IPEX version AWQ if you have an Intel CPU"
-                )
-
-            if device_map is None:
-                logger.warning_once(
-                    "You have loaded an AWQ model on CPU and have a CUDA/XPU device available, make sure to set "
-                    "your model on a GPU device in order to run your model."
-                )
-            elif device_map is not None:
-                if isinstance(device_map, dict) and any(
-                    forbidden in device_map.values() for forbidden in ("cpu", torch.device("cpu"), "disk")
-                ):
-                    raise ValueError(
-                        "You are attempting to load an AWQ model with a device_map that contains a CPU or disk device."
-                        " This is not supported. Please remove the CPU or disk device from the device_map."
-                    )
-
     def update_dtype(self, dtype):
         if dtype is None:
             dtype = torch.float16
@@ -116,42 +74,22 @@ class AwqQuantizer(HfQuantizer):
             model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules, add_default_skips=True
         )

-        model
-            model,
+        model = replace_with_awq_linear(
+            model,
+            quantization_config=self.quantization_config,
+            modules_to_not_convert=self.modules_to_not_convert,
+            device_map=kwargs.get("device_map"),
         )

         model = replace_quantization_scales(model, model.config.model_type)

-        if not has_been_replaced:
-            logger.warning(
-                "You are loading an AWQ model but no linear modules were found in your model."
-                " Please double check your model architecture, or submit an issue on github if you think this is a bug."
-            )
-
     def _process_model_after_weight_loading(self, model, **kwargs):
-
-        from ..integrations import fuse_awq_modules
+        from gptqmodel.utils.model import hf_gptqmodel_post_init

-
-        model._awq_is_fused = True  # TODO: consider storing this flag in model.config instead
-
-        if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
-            from ..integrations import post_init_awq_exllama_modules
-
-            model = post_init_awq_exllama_modules(model, self.quantization_config.exllama_config)
-
-        if self.quantization_config.version == AWQLinearVersion.IPEX:
-            from ..integrations import post_init_awq_ipex_modules
-
-            model = post_init_awq_ipex_modules(model)
-
-    def is_serializable(self, safe_serialization=None):
-        # AWQ through auto-awq has been always serializable, except if the model is fused.
-        if self.quantization_config.do_fuse:
-            logger.warning("You cannot save an AWQ model that uses fused modules!")
-            return False
+        hf_gptqmodel_post_init(model, use_act_order=self.quantization_config.desc_act)

-
+    def is_serializable(self):
+        if self.quantization_config.backend in [AwqBackend.EXLLAMA_V1, AwqBackend.EXLLAMA_V2]:
             logger.warning("You cannot save an AWQ model that uses Exllama backend!")
             return False
@@ -159,6 +97,4 @@ class AwqQuantizer(HfQuantizer):

     @property
     def is_trainable(self):
-
-        MIN_AWQ_VERSION_FOR_PEFT = "0.2.0"
-        return version.parse(importlib.metadata.version("autoawq")) >= version.parse(MIN_AWQ_VERSION_FOR_PEFT)
+        return version.parse(importlib.metadata.version("gptqmodel")) >= version.parse("5.0.0")
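The net effect of these hunks is that AWQ checkpoints are now served by gptqmodel rather than autoawq: environment validation checks is_gptqmodel_available(), post-loading runs hf_gptqmodel_post_init, and trainability/serializability are gated on the gptqmodel version and the configured AwqBackend. A hedged loading sketch under those assumptions (the checkpoint id is illustrative):

# Hedged sketch: loading a pre-quantized AWQ checkpoint on the new gptqmodel
# path. Assumes `pip install gptqmodel accelerate`; any AWQ repo id should work.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.float16,  # update_dtype() above falls back to float16 when dtype is None
    device_map="auto",
)
# Per is_serializable() above, saving is refused when the configured backend
# is AwqBackend.EXLLAMA_V1 or AwqBackend.EXLLAMA_V2.
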
transformers/quantizers/quantizer_bitnet.py
CHANGED

@@ -37,14 +37,10 @@ class BitNetHfQuantizer(HfQuantizer):
     Check out the paper introducing this method: https://huggingface.co/papers/2402.17764
     """

-    requires_parameters_quantization = False
     requires_calibration = True

-    required_packages = ["accelerate"]
-
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.quantization_config = quantization_config

     def validate_environment(self, *args, **kwargs):
         if not is_accelerate_available():
@@ -62,8 +58,8 @@ class BitNetHfQuantizer(HfQuantizer):
                 "You have loaded a BitNet model on CPU and have a CUDA device available, make sure to set "
                 "your model on a GPU device in order to run your model."
             )
-        elif device_map
-            if
+        elif isinstance(device_map, dict):
+            if len(device_map) > 1 and "cpu" in device_map.values() or "disk" in device_map.values():
                 raise ValueError(
                     "You are attempting to load a BitNet model with a device_map that contains a CPU or disk device."
                     "This is not supported. Please remove the CPU or disk device from the device_map."
@@ -85,7 +81,6 @@ class BitNetHfQuantizer(HfQuantizer):
             model,
             modules_to_not_convert=self.modules_to_not_convert,
             quantization_config=self.quantization_config,
-            pre_quantized=self.pre_quantized,
         )

     def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
@@ -96,7 +91,7 @@ class BitNetHfQuantizer(HfQuantizer):
         target_dtype = torch.int8
         return target_dtype

-    def is_serializable(self
+    def is_serializable(self):
         return True

     @property
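One detail worth calling out in the new device_map condition: because `and` binds tighter than `or` in Python, any map containing "disk" is rejected, while "cpu" is rejected only when the map spans more than one entry, so a pure-CPU map still loads. A small illustrative sketch of that rule (the helper name is hypothetical, not part of the library):

# Hedged sketch of the precedence in the condition above:
# `len(device_map) > 1 and "cpu" in values or "disk" in values`
# parses as `(len(device_map) > 1 and "cpu" in values) or ("disk" in values)`.
def _bitnet_map_rejected(device_map: dict) -> bool:  # hypothetical helper
    values = device_map.values()
    return (len(device_map) > 1 and "cpu" in values) or "disk" in values

assert not _bitnet_map_rejected({"": "cpu"})                 # pure CPU map: allowed
assert _bitnet_map_rejected({"model": 0, "lm_head": "cpu"})  # mixed CPU offload: rejected
assert _bitnet_map_rejected({"": "disk"})                    # any disk entry: rejected
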
transformers/quantizers/quantizer_bnb_4bit.py
CHANGED

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from collections import defaultdict
 from typing import TYPE_CHECKING

 from .base import HfQuantizer
@@ -38,34 +37,20 @@ if is_torch_available():
     import torch

     from ..core_model_loading import WeightConverter
-    from ..pytorch_utils import Conv1D

 logger = logging.get_logger(__name__)


 class Bnb4BitHfQuantizer(HfQuantizer):
     """
-    4-bit quantization from bitsandbytes quantization method
-    before loading: converts transformer layers into Linear4bit during loading: load 16bit weight and pass to the
-    layer object after: quantizes individual weights in Linear4bit into 4bit at the first .cuda() call
-    saving:
-        from state dict, as usual; saves weights and `quant_state` components
-    loading:
-        need to locate `quant_state` components and pass to Param4bit constructor
+    4-bit quantization from bitsandbytes quantization method
     """

-    use_keep_in_fp32_modules = True
-    requires_parameters_quantization = True
     requires_calibration = False

-    required_packages = ["bitsandbytes", "accelerate"]
-
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)

-        if self.quantization_config.llm_int8_skip_modules is not None:
-            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
-
         # This describes the additional items that are saved on the state dict (on the params themselves)
         self.bnb_keys = [
             f"quant_state.bitsandbytes__{self.quantization_config.bnb_4bit_quant_type}",
@@ -90,17 +75,9 @@ class Bnb4BitHfQuantizer(HfQuantizer):
         validate_bnb_backend_availability(raise_exception=True)

         device_map = kwargs.get("device_map")
-        if (
-
-            and
-            and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
-        ):
-            device_map_without_lm_head = {
-                key: device_map[key] for key in device_map if key not in self.modules_to_not_convert
-            }
-            if set(device_map.values()) == {"cpu"}:
-                pass
-            elif "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
+        if not self.quantization_config.llm_int8_enable_fp32_cpu_offload and isinstance(device_map, dict):
+            values = set(device_map.values())
+            if values != {"cpu"} and ("cpu" in values or "disk" in values):
                 raise ValueError(
                     "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
                     "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
@@ -117,13 +94,11 @@ class Bnb4BitHfQuantizer(HfQuantizer):
         logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization")
         return CustomDtype.INT4

-    def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
-        return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.bnb_keys)]
-
     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         import bitsandbytes as bnb

-        #
+        # TODO: maybe remove
+        # # They are on the params themselves, so we cannot easily extract the module from the name
         if any(param_name.endswith(x) for x in self.bnb_keys):
             return True
         module, name = get_module_from_name(model, param_name)
@@ -142,71 +117,13 @@ class Bnb4BitHfQuantizer(HfQuantizer):
         )
         return param_name

-    def create_quantized_param(
-        self,
-        model: "PreTrainedModel",
-        param_value: "torch.Tensor",
-        param_name: str,
-        target_device: "torch.device",
-        **kwargs,
-    ):
-        import bitsandbytes as bnb
-
-        full_name = param_name
-
-        # update param name to get the weights instead of the quantized stats
-        param_name = self.get_param_name(param_name)
-        module, tensor_name = get_module_from_name(model, param_name)
-
-        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
-        if isinstance(target_device, int) and is_torch_npu_available():
-            target_device = f"npu:{target_device}"
-
-        # construct `new_value` for the module._parameters[tensor_name]
-        if self.pre_quantized:
-            module_name = param_name.rsplit(".", 1)[0]
-            # Save the states for later quantization when they are all gathered
-            if not hasattr(self, "param_quant_stats"):
-                self.param_quant_stats = defaultdict(dict)
-            self.param_quant_stats[module_name].update({full_name: param_value})
-
-            # We are ready for quantization in this case (note, the +1 is for the weight itself)
-            if len(self.param_quant_stats[module_name]) == len(self.bnb_keys) + 1:
-                weight = self.param_quant_stats[module_name].pop(f"{module_name}.weight")
-                new_value = bnb.nn.Params4bit.from_prequantized(
-                    data=weight,
-                    quantized_stats=self.param_quant_stats[module_name],
-                    requires_grad=False,
-                    device=target_device,
-                    module=module,
-                )
-                # Set it
-                module._parameters[tensor_name] = new_value
-                # Delete the states
-                del self.param_quant_stats[module_name]
-        else:
-            new_value = param_value.to("cpu")
-            old_value = getattr(module, tensor_name)
-
-            # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
-            # Since weights are saved in the correct "orientation", we skip transposing when loading.
-            if issubclass(module.source_cls, Conv1D):
-                new_value = new_value.T
-
-            kwargs = old_value.__dict__
-            kwargs.pop("_is_hf_initialized", None)
-            new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
-
-            module._parameters[tensor_name] = new_value
-
-    # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.adjust_max_memory
     def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
         # need more space for buffers that are created during quantization
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
         return max_memory

-    # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_dtype
     def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
+        # TODO: remove ? is it still true ? we will move to dtype = "auto" so it will likely be either fp16 or bf16
         if dtype is None:
             # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
             logger.info(
@@ -238,7 +155,6 @@ class Bnb4BitHfQuantizer(HfQuantizer):
         )
         return device_map

-    # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_before_weight_loading
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
@@ -248,23 +164,15 @@ class Bnb4BitHfQuantizer(HfQuantizer):
     ):
         from ..integrations import replace_with_bnb_linear

-        llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
-
         self.modules_to_not_convert = self.get_modules_to_not_convert(
             model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
         )

-
-
-
+        if self.quantization_config.llm_int8_enable_fp32_cpu_offload:
+            if isinstance(device_map, dict):
+                keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+                self.modules_to_not_convert.extend(keys_on_cpu)

-        if len(keys_on_cpu) > 0 and not llm_int8_enable_fp32_cpu_offload:
-            raise ValueError(
-                "If you want to offload some keys to `cpu` or `disk`, you need to set "
-                "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
-                " converted to 8-bit but kept in 32-bit."
-            )
-        self.modules_to_not_convert.extend(keys_on_cpu)
         model = replace_with_bnb_linear(
             model,
             modules_to_not_convert=self.modules_to_not_convert,
@@ -272,15 +180,12 @@ class Bnb4BitHfQuantizer(HfQuantizer):
             pre_quantized=self.pre_quantized,
         )

-        model.config.quantization_config = self.quantization_config
-
-    # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
         model.is_loaded_in_4bit = True
         model.is_4bit_serializable = self.is_serializable()
         return model

-    def is_serializable(self
+    def is_serializable(self):
         return True

     @property
@@ -290,9 +195,7 @@ class Bnb4BitHfQuantizer(HfQuantizer):
     def _dequantize(self, model):
         from ..integrations import dequantize_and_replace

-        model = dequantize_and_replace(
-            model, self.modules_to_not_convert, quantization_config=self.quantization_config
-        )
+        model = dequantize_and_replace(model, quantization_config=self.quantization_config)
         return model

     def get_quantize_ops(self):
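Taken together, the bnb-4bit changes drop per-parameter quantization from create_quantized_param (apparently in favor of the WeightConverter/get_quantize_ops loading machinery) and simplify the device_map rules: cpu/disk entries are rejected outright unless llm_int8_enable_fp32_cpu_offload=True, in which case those modules are added to modules_to_not_convert and kept unquantized. A hedged usage sketch (the model id and device_map split are illustrative):

# Hedged sketch: 4-bit loading with fp32 CPU offload under the simplified
# validation above. The checkpoint and the module split are illustrative.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_enable_fp32_cpu_offload=True,  # without this, "cpu"/"disk" entries raise
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    quantization_config=quant_config,
    device_map={"model": 0, "lm_head": "cpu"},  # lm_head stays unquantized on CPU
)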