transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -68,13 +68,12 @@ class BitNetHfQuantizer(HfQuantizer):
|
|
|
68
68
|
def _process_model_before_weight_loading(
|
|
69
69
|
self,
|
|
70
70
|
model: "PreTrainedModel",
|
|
71
|
-
keep_in_fp32_modules: list[str] | None = None,
|
|
72
71
|
**kwargs,
|
|
73
72
|
):
|
|
74
73
|
from ..integrations import replace_with_bitnet_linear
|
|
75
74
|
|
|
76
75
|
self.modules_to_not_convert = self.get_modules_to_not_convert(
|
|
77
|
-
model, self.quantization_config.modules_to_not_convert,
|
|
76
|
+
model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
|
|
78
77
|
)
|
|
79
78
|
|
|
80
79
|
model = replace_with_bitnet_linear(
|
|
@@ -87,10 +86,6 @@ class BitNetHfQuantizer(HfQuantizer):
|
|
|
87
86
|
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
|
|
88
87
|
return max_memory
|
|
89
88
|
|
|
90
|
-
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
|
91
|
-
target_dtype = torch.int8
|
|
92
|
-
return target_dtype
|
|
93
|
-
|
|
94
89
|
def is_serializable(self):
|
|
95
90
|
return True
|
|
96
91
|
|
|
@@ -51,15 +51,6 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
|
|
51
51
|
def __init__(self, quantization_config, **kwargs):
|
|
52
52
|
super().__init__(quantization_config, **kwargs)
|
|
53
53
|
|
|
54
|
-
# This describes the additional items that are saved on the state dict (on the params themselves)
|
|
55
|
-
self.bnb_keys = [
|
|
56
|
-
f"quant_state.bitsandbytes__{self.quantization_config.bnb_4bit_quant_type}",
|
|
57
|
-
"absmax",
|
|
58
|
-
"quant_map",
|
|
59
|
-
]
|
|
60
|
-
if self.quantization_config.bnb_4bit_use_double_quant:
|
|
61
|
-
self.bnb_keys.extend(["nested_absmax", "nested_quant_map"])
|
|
62
|
-
|
|
63
54
|
def validate_environment(self, *args, **kwargs):
|
|
64
55
|
if not is_accelerate_available():
|
|
65
56
|
raise ImportError(
|
|
@@ -87,55 +78,25 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
|
|
87
78
|
"for more details. "
|
|
88
79
|
)
|
|
89
80
|
|
|
90
|
-
def
|
|
91
|
-
|
|
81
|
+
def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
|
|
82
|
+
"Return the element size (in bytes) for `param_name`."
|
|
83
|
+
if self.param_needs_quantization(model, param_name):
|
|
84
|
+
# 4 bit
|
|
85
|
+
return 0.5
|
|
92
86
|
|
|
93
|
-
|
|
94
|
-
logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization")
|
|
95
|
-
return CustomDtype.INT4
|
|
87
|
+
return super().param_element_size(model, param_name, param)
|
|
96
88
|
|
|
97
89
|
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
|
98
90
|
import bitsandbytes as bnb
|
|
99
91
|
|
|
100
|
-
# TODO: maybe remove
|
|
101
|
-
# # They are on the params themselves, so we cannot easily extract the module from the name
|
|
102
|
-
if any(param_name.endswith(x) for x in self.bnb_keys):
|
|
103
|
-
return True
|
|
104
92
|
module, name = get_module_from_name(model, param_name)
|
|
105
93
|
return isinstance(module, bnb.nn.Linear4bit) and name != "bias"
|
|
106
94
|
|
|
107
|
-
def get_param_name(self, param_name: str) -> str:
|
|
108
|
-
"""
|
|
109
|
-
Get the right param_name in order to get the module associated with the param.
|
|
110
|
-
This is useful for quantized stats lile absmax or quant_map as we need to update the param_name to get the module as they are stored in ...weight.absmax.
|
|
111
|
-
"""
|
|
112
|
-
if self.pre_quantized:
|
|
113
|
-
# We need to get the param name of quantized weights and not its components. Otherwise, we won't be able to get the nn.Module associated.
|
|
114
|
-
if any(param_name.endswith(x) for x in self.bnb_keys):
|
|
115
|
-
param_name = (
|
|
116
|
-
param_name.rsplit(".", 1)[0] if "quant_state." not in param_name else param_name.rsplit(".", 2)[0]
|
|
117
|
-
)
|
|
118
|
-
return param_name
|
|
119
|
-
|
|
120
95
|
def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
|
|
121
96
|
# need more space for buffers that are created during quantization
|
|
122
97
|
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
|
|
123
98
|
return max_memory
|
|
124
99
|
|
|
125
|
-
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
|
|
126
|
-
# TODO: remove ? is it still true ? we will move to dtype = "auto" so it will likely be either fp16 or bf16
|
|
127
|
-
if dtype is None:
|
|
128
|
-
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
|
|
129
|
-
logger.info(
|
|
130
|
-
"Overriding dtype=%s with `dtype=torch.float16` due to "
|
|
131
|
-
"requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
|
|
132
|
-
"Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
|
|
133
|
-
" dtype=torch.float16 to remove this warning.",
|
|
134
|
-
dtype,
|
|
135
|
-
)
|
|
136
|
-
dtype = torch.float16
|
|
137
|
-
return dtype
|
|
138
|
-
|
|
139
100
|
def update_device_map(self, device_map):
|
|
140
101
|
if device_map is None:
|
|
141
102
|
if torch.cuda.is_available():
|
|
@@ -159,13 +120,12 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
|
|
159
120
|
self,
|
|
160
121
|
model: "PreTrainedModel",
|
|
161
122
|
device_map,
|
|
162
|
-
keep_in_fp32_modules: list[str] | None = None,
|
|
163
123
|
**kwargs,
|
|
164
124
|
):
|
|
165
125
|
from ..integrations import replace_with_bnb_linear
|
|
166
126
|
|
|
167
127
|
self.modules_to_not_convert = self.get_modules_to_not_convert(
|
|
168
|
-
model, self.quantization_config.llm_int8_skip_modules,
|
|
128
|
+
model, self.quantization_config.llm_int8_skip_modules, model._keep_in_fp32_modules
|
|
169
129
|
)
|
|
170
130
|
|
|
171
131
|
if self.quantization_config.llm_int8_enable_fp32_cpu_offload:
|
|
@@ -192,10 +152,10 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
|
|
192
152
|
def is_trainable(self) -> bool:
|
|
193
153
|
return True
|
|
194
154
|
|
|
195
|
-
def _dequantize(self, model):
|
|
155
|
+
def _dequantize(self, model, dtype=None):
|
|
196
156
|
from ..integrations import dequantize_and_replace
|
|
197
157
|
|
|
198
|
-
model = dequantize_and_replace(model, quantization_config=self.quantization_config)
|
|
158
|
+
model = dequantize_and_replace(model, quantization_config=self.quantization_config, dtype=dtype)
|
|
199
159
|
return model
|
|
200
160
|
|
|
201
161
|
def get_quantize_ops(self):
|
|
@@ -83,19 +83,6 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
|
|
83
83
|
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
|
|
84
84
|
return max_memory
|
|
85
85
|
|
|
86
|
-
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
|
|
87
|
-
if dtype is None:
|
|
88
|
-
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
|
|
89
|
-
logger.info(
|
|
90
|
-
"Overriding dtype=%s with `dtype=torch.float16` due to "
|
|
91
|
-
"requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
|
|
92
|
-
"Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
|
|
93
|
-
" dtype=torch.float16 to remove this warning.",
|
|
94
|
-
dtype,
|
|
95
|
-
)
|
|
96
|
-
dtype = torch.float16
|
|
97
|
-
return dtype
|
|
98
|
-
|
|
99
86
|
def update_device_map(self, device_map):
|
|
100
87
|
if device_map is None:
|
|
101
88
|
if torch.cuda.is_available():
|
|
@@ -115,8 +102,12 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
|
|
115
102
|
)
|
|
116
103
|
return device_map
|
|
117
104
|
|
|
118
|
-
def
|
|
119
|
-
|
|
105
|
+
def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
|
|
106
|
+
"Return the element size (in bytes) for `param_name`."
|
|
107
|
+
if self.param_needs_quantization(model, param_name):
|
|
108
|
+
# 8-bit
|
|
109
|
+
return 1
|
|
110
|
+
return super().param_element_size(model, param_name, param)
|
|
120
111
|
|
|
121
112
|
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
|
122
113
|
import bitsandbytes as bnb
|
|
@@ -133,13 +124,12 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
|
|
133
124
|
self,
|
|
134
125
|
model: "PreTrainedModel",
|
|
135
126
|
device_map,
|
|
136
|
-
keep_in_fp32_modules: list[str] | None = None,
|
|
137
127
|
**kwargs,
|
|
138
128
|
):
|
|
139
129
|
from ..integrations import replace_with_bnb_linear
|
|
140
130
|
|
|
141
131
|
self.modules_to_not_convert = self.get_modules_to_not_convert(
|
|
142
|
-
model, self.quantization_config.llm_int8_skip_modules,
|
|
132
|
+
model, self.quantization_config.llm_int8_skip_modules, model._keep_in_fp32_modules
|
|
143
133
|
)
|
|
144
134
|
|
|
145
135
|
if self.quantization_config.llm_int8_enable_fp32_cpu_offload:
|
|
@@ -161,10 +151,10 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
|
|
161
151
|
def is_trainable(self) -> bool:
|
|
162
152
|
return True
|
|
163
153
|
|
|
164
|
-
def _dequantize(self, model):
|
|
154
|
+
def _dequantize(self, model, dtype=None):
|
|
165
155
|
from ..integrations import dequantize_and_replace
|
|
166
156
|
|
|
167
|
-
model = dequantize_and_replace(model, quantization_config=self.quantization_config)
|
|
157
|
+
model = dequantize_and_replace(model, quantization_config=self.quantization_config, dtype=dtype)
|
|
168
158
|
return model
|
|
169
159
|
|
|
170
160
|
def get_quantize_ops(self):
|
|
@@ -59,10 +59,7 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
|
|
|
59
59
|
)
|
|
60
60
|
|
|
61
61
|
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
|
|
62
|
-
if dtype
|
|
63
|
-
logger.info("Loading model using torch.float16 for compressed-tensors quantization")
|
|
64
|
-
dtype = torch.float16
|
|
65
|
-
elif dtype != torch.float16:
|
|
62
|
+
if dtype != torch.float16:
|
|
66
63
|
logger.info("We suggest you to set `dtype=torch.float16` for better efficiency with compressed_tensors.")
|
|
67
64
|
return dtype
|
|
68
65
|
|
|
@@ -64,16 +64,7 @@ class EetqHfQuantizer(HfQuantizer):
|
|
|
64
64
|
)
|
|
65
65
|
|
|
66
66
|
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
|
|
67
|
-
if dtype
|
|
68
|
-
dtype = torch.float16
|
|
69
|
-
logger.info(
|
|
70
|
-
"Overriding dtype=%s with `dtype=torch.float16` due to "
|
|
71
|
-
"requirements of `eetq` to enable model loading in 8-bit. "
|
|
72
|
-
"Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
|
|
73
|
-
" dtype=torch.float16 to remove this warning.",
|
|
74
|
-
dtype,
|
|
75
|
-
)
|
|
76
|
-
elif dtype != torch.float16:
|
|
67
|
+
if dtype != torch.float16:
|
|
77
68
|
logger.info("We suggest you to set `dtype=torch.float16` for better efficiency with EETQ.")
|
|
78
69
|
return dtype
|
|
79
70
|
|
|
@@ -92,13 +83,12 @@ class EetqHfQuantizer(HfQuantizer):
|
|
|
92
83
|
def _process_model_before_weight_loading(
|
|
93
84
|
self,
|
|
94
85
|
model: "PreTrainedModel",
|
|
95
|
-
keep_in_fp32_modules: list[str] | None = None,
|
|
96
86
|
**kwargs,
|
|
97
87
|
):
|
|
98
88
|
from ..integrations import replace_with_eetq_linear
|
|
99
89
|
|
|
100
90
|
self.modules_to_not_convert = self.get_modules_to_not_convert(
|
|
101
|
-
model, self.quantization_config.modules_to_not_convert,
|
|
91
|
+
model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
|
|
102
92
|
)
|
|
103
93
|
|
|
104
94
|
model = replace_with_eetq_linear(
|
|
@@ -84,19 +84,11 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
|
|
|
84
84
|
)
|
|
85
85
|
|
|
86
86
|
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
|
|
87
|
-
if dtype
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
"Overriding dtype=%s with `dtype=torch.bloat16` due to "
|
|
91
|
-
"requirements of `fbgemm-gpu` to enable model loading in fp8. "
|
|
92
|
-
"Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
|
|
93
|
-
" dtype=torch.bfloat16 to remove this warning.",
|
|
94
|
-
dtype,
|
|
95
|
-
)
|
|
96
|
-
elif dtype == torch.float16:
|
|
97
|
-
raise ValueError(
|
|
98
|
-
"You cannot use FP8 with dtype=torch.float16. We recommend you passing dtype=torch.bfloat16"
|
|
87
|
+
if dtype != torch.bfloat16:
|
|
88
|
+
logger.warning_once(
|
|
89
|
+
f"Setting dtype to {dtype}, but only bfloat16 is supported right now. Overwriting torch_dtype to bfloat16."
|
|
99
90
|
)
|
|
91
|
+
dtype = torch.bfloat16
|
|
100
92
|
return dtype
|
|
101
93
|
|
|
102
94
|
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
|
@@ -119,13 +111,12 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
|
|
|
119
111
|
def _process_model_before_weight_loading(
|
|
120
112
|
self,
|
|
121
113
|
model: "PreTrainedModel",
|
|
122
|
-
keep_in_fp32_modules: list[str] | None = None,
|
|
123
114
|
**kwargs,
|
|
124
115
|
):
|
|
125
116
|
from ..integrations import replace_with_fbgemm_fp8_linear
|
|
126
117
|
|
|
127
118
|
self.modules_to_not_convert = self.get_modules_to_not_convert(
|
|
128
|
-
model, self.quantization_config.modules_to_not_convert,
|
|
119
|
+
model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
|
|
129
120
|
)
|
|
130
121
|
|
|
131
122
|
model = replace_with_fbgemm_fp8_linear(
|
|
@@ -33,7 +33,7 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
|
|
|
33
33
|
return
|
|
34
34
|
|
|
35
35
|
if not torch.cuda.is_available() and not is_torch_xpu_available():
|
|
36
|
-
if self.pre_quantized
|
|
36
|
+
if self.pre_quantized:
|
|
37
37
|
logger.warning_once(
|
|
38
38
|
"Using FP8 quantized models requires a GPU or XPU, we will default to dequantizing the model to bf16 since no GPU or XPU is available"
|
|
39
39
|
)
|
|
@@ -46,10 +46,13 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
|
|
|
46
46
|
compute_capability = torch.cuda.get_device_capability()
|
|
47
47
|
major, minor = compute_capability
|
|
48
48
|
if (major < 8) or (major == 8 and minor < 9):
|
|
49
|
-
|
|
49
|
+
logger.warning_once(
|
|
50
50
|
"FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
|
|
51
|
-
f", actual = `{major}.{minor}
|
|
51
|
+
f", actual = `{major}.{minor}`. We will default to dequantizing the model to bf16. Feel free "
|
|
52
|
+
f"to use a different quantization method like bitsandbytes or torchao"
|
|
52
53
|
)
|
|
54
|
+
self.quantization_config.dequantize = True
|
|
55
|
+
return
|
|
53
56
|
|
|
54
57
|
device_map = kwargs.get("device_map")
|
|
55
58
|
if device_map is None:
|
|
@@ -82,16 +85,22 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
|
|
|
82
85
|
return True
|
|
83
86
|
return False
|
|
84
87
|
|
|
88
|
+
def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
|
|
89
|
+
"Return the element size (in bytes) for `param_name`."
|
|
90
|
+
if self.param_needs_quantization(model, param_name):
|
|
91
|
+
# 8 bit, this is neeed as when `pre_quantized`` is False, we don't set the dtype of the FP8Linear in order to correctly load the weights
|
|
92
|
+
return 1
|
|
93
|
+
return super().param_element_size(model, param_name, param)
|
|
94
|
+
|
|
85
95
|
def _process_model_before_weight_loading(
|
|
86
96
|
self,
|
|
87
97
|
model: "PreTrainedModel",
|
|
88
|
-
keep_in_fp32_modules: list[str] | None = None,
|
|
89
98
|
**kwargs,
|
|
90
99
|
):
|
|
91
100
|
from ..integrations.finegrained_fp8 import replace_with_fp8_linear
|
|
92
101
|
|
|
93
102
|
self.modules_to_not_convert = self.get_modules_to_not_convert(
|
|
94
|
-
model, self.quantization_config.modules_to_not_convert,
|
|
103
|
+
model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
|
|
95
104
|
)
|
|
96
105
|
|
|
97
106
|
model = replace_with_fp8_linear(
|
|
@@ -103,7 +112,7 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
|
|
|
103
112
|
|
|
104
113
|
# NOTE: TP is applied before quantization so this is only to add hooks.
|
|
105
114
|
# Quantization is incompatible with DTensors, so we have to anyway have
|
|
106
|
-
# gathers! But it should be model
|
|
115
|
+
# gathers! But it should be model independent -> figure out where to put
|
|
107
116
|
# the gather and that's it.
|
|
108
117
|
def update_tp_plan(self, config):
|
|
109
118
|
if "Qwen3" in config.__class__.__name__:
|
|
@@ -137,10 +146,6 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
|
|
|
137
146
|
def is_trainable(self) -> bool:
|
|
138
147
|
return False
|
|
139
148
|
|
|
140
|
-
def get_accelerator_warm_up_factor(self):
|
|
141
|
-
# Pre-processing is done cleanly, so we can allocate everything here
|
|
142
|
-
return 2
|
|
143
|
-
|
|
144
149
|
def get_quantize_ops(self):
|
|
145
150
|
from ..integrations.finegrained_fp8 import Fp8Quantize
|
|
146
151
|
|
|
@@ -78,11 +78,11 @@ class FPQuantHfQuantizer(HfQuantizer):
|
|
|
78
78
|
)
|
|
79
79
|
|
|
80
80
|
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
|
|
81
|
-
if dtype
|
|
82
|
-
logger.
|
|
81
|
+
if dtype != torch.bfloat16:
|
|
82
|
+
logger.warning_once(
|
|
83
|
+
f"Setting dtype to {dtype}, but only bfloat16 is supported right now. Overwriting torch_dtype to bfloat16."
|
|
84
|
+
)
|
|
83
85
|
dtype = torch.bfloat16
|
|
84
|
-
elif dtype != torch.bfloat16:
|
|
85
|
-
raise ValueError(f"Invalid `dtype` {dtype}. fp_quant quantization only supports `dtype=torch.bfloat16`.")
|
|
86
86
|
return dtype
|
|
87
87
|
|
|
88
88
|
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
|
@@ -66,10 +66,7 @@ class GptqHfQuantizer(HfQuantizer):
|
|
|
66
66
|
raise ImportError("The gptqmodel version should be >= 1.4.3, optimum version should >= 1.24.0")
|
|
67
67
|
|
|
68
68
|
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
|
|
69
|
-
if dtype
|
|
70
|
-
dtype = torch.float16
|
|
71
|
-
logger.info("Loading the model in `torch.float16`. To overwrite it, set `dtype` manually.")
|
|
72
|
-
elif dtype != torch.float16:
|
|
69
|
+
if dtype != torch.float16:
|
|
73
70
|
logger.info("We suggest you to set `dtype=torch.float16` for better efficiency with GPTQ.")
|
|
74
71
|
return dtype
|
|
75
72
|
|
|
@@ -69,10 +69,7 @@ class HiggsHfQuantizer(HfQuantizer):
|
|
|
69
69
|
)
|
|
70
70
|
|
|
71
71
|
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
|
|
72
|
-
if dtype
|
|
73
|
-
logger.info("`dtype` is None. Setting `dtype=torch.float16` for FLUTE compatibility.")
|
|
74
|
-
dtype = torch.float16
|
|
75
|
-
elif dtype != torch.float16 and dtype != torch.bfloat16:
|
|
72
|
+
if dtype != torch.float16 and dtype != torch.bfloat16:
|
|
76
73
|
raise ValueError(
|
|
77
74
|
f"Invalid `dtype` {dtype}. HIGGS quantization only supports `dtype=torch.float16` or `dtype=torch.bfloat16`."
|
|
78
75
|
)
|
|
@@ -116,13 +113,12 @@ class HiggsHfQuantizer(HfQuantizer):
|
|
|
116
113
|
def _process_model_before_weight_loading(
|
|
117
114
|
self,
|
|
118
115
|
model: "PreTrainedModel",
|
|
119
|
-
keep_in_fp32_modules: list[str] | None = None,
|
|
120
116
|
**kwargs,
|
|
121
117
|
):
|
|
122
118
|
from ..integrations import replace_with_higgs_linear
|
|
123
119
|
|
|
124
120
|
self.modules_to_not_convert = self.get_modules_to_not_convert(
|
|
125
|
-
model, self.quantization_config.modules_to_not_convert,
|
|
121
|
+
model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
|
|
126
122
|
)
|
|
127
123
|
|
|
128
124
|
replace_with_higgs_linear(
|
|
@@ -53,7 +53,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
|
|
53
53
|
"""Lazy import and initialize kernels only when needed"""
|
|
54
54
|
if self.triton_kernels_hub is None:
|
|
55
55
|
try:
|
|
56
|
-
from
|
|
56
|
+
from ..integrations.hub_kernels import get_kernel
|
|
57
57
|
|
|
58
58
|
self.triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
|
|
59
59
|
except ImportError:
|
|
@@ -135,18 +135,6 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
|
|
135
135
|
"Please use a quantized checkpoint or remove the CPU or disk device from the device_map."
|
|
136
136
|
)
|
|
137
137
|
|
|
138
|
-
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
|
|
139
|
-
if dtype is None:
|
|
140
|
-
dtype = torch.bfloat16
|
|
141
|
-
logger.info(
|
|
142
|
-
"Overriding dtype=%s with `dtype=torch.bfloat16` due to "
|
|
143
|
-
"requirements of `fbgemm-gpu` to enable model loading in fp4. "
|
|
144
|
-
"Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
|
|
145
|
-
" dtype=torch.bfloat16 to remove this warning.",
|
|
146
|
-
dtype,
|
|
147
|
-
)
|
|
148
|
-
return dtype
|
|
149
|
-
|
|
150
138
|
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
|
151
139
|
from ..integrations import Mxfp4GptOssExperts
|
|
152
140
|
|
|
@@ -167,7 +155,6 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
|
|
167
155
|
def _process_model_before_weight_loading(
|
|
168
156
|
self,
|
|
169
157
|
model: "PreTrainedModel",
|
|
170
|
-
keep_in_fp32_modules: list[str] | None = None,
|
|
171
158
|
use_kernels: bool = False,
|
|
172
159
|
**kwargs,
|
|
173
160
|
):
|
|
@@ -182,7 +169,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
|
|
182
169
|
self.quantization_config.dequantize = True
|
|
183
170
|
|
|
184
171
|
self.modules_to_not_convert = self.get_modules_to_not_convert(
|
|
185
|
-
model, self.quantization_config.modules_to_not_convert,
|
|
172
|
+
model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
|
|
186
173
|
)
|
|
187
174
|
|
|
188
175
|
model = replace_with_mxfp4_linear(
|
|
@@ -215,19 +202,6 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
|
|
215
202
|
)
|
|
216
203
|
return config
|
|
217
204
|
|
|
218
|
-
def get_param_name(self, param_name: str) -> str:
|
|
219
|
-
if self.quantization_config.dequantize:
|
|
220
|
-
if "_blocks" in param_name:
|
|
221
|
-
return param_name.replace("_blocks", "")
|
|
222
|
-
elif "_scales" in param_name:
|
|
223
|
-
return param_name.replace("_scales", "")
|
|
224
|
-
elif not self.pre_quantized:
|
|
225
|
-
if param_name.endswith("gate_up_proj"):
|
|
226
|
-
return param_name.replace("gate_up_proj", "gate_up_proj_blocks")
|
|
227
|
-
if param_name.endswith("down_proj"):
|
|
228
|
-
return param_name.replace("down_proj", "down_proj_blocks")
|
|
229
|
-
return param_name
|
|
230
|
-
|
|
231
205
|
def get_state_dict_and_metadata(self, model):
|
|
232
206
|
from ..integrations import Mxfp4GptOssExperts
|
|
233
207
|
|
|
@@ -44,6 +44,13 @@ class QuantoHfQuantizer(HfQuantizer):
|
|
|
44
44
|
|
|
45
45
|
def __init__(self, quantization_config: QuantoConfig, **kwargs):
|
|
46
46
|
super().__init__(quantization_config, **kwargs)
|
|
47
|
+
map_to_param_size = {
|
|
48
|
+
"int8": 1,
|
|
49
|
+
"float8": 1,
|
|
50
|
+
"int4": 0.5,
|
|
51
|
+
"int2": 0.25,
|
|
52
|
+
}
|
|
53
|
+
self.quantized_param_size = map_to_param_size.get(self.quantization_config.weights, None)
|
|
47
54
|
|
|
48
55
|
def validate_environment(self, *args, **kwargs):
|
|
49
56
|
if not is_optimum_quanto_available():
|
|
@@ -83,25 +90,18 @@ class QuantoHfQuantizer(HfQuantizer):
|
|
|
83
90
|
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
|
|
84
91
|
return max_memory
|
|
85
92
|
|
|
86
|
-
def
|
|
87
|
-
|
|
93
|
+
def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
|
|
94
|
+
"Return the element size (in bytes) for `param_name`."
|
|
95
|
+
if self.param_needs_quantization(model, param_name) and self.quantized_param_size is not None:
|
|
96
|
+
return self.quantized_param_size
|
|
88
97
|
|
|
89
|
-
|
|
90
|
-
"int8": torch.int8,
|
|
91
|
-
"float8": CustomDtype.FP8,
|
|
92
|
-
"int4": CustomDtype.INT4,
|
|
93
|
-
"int2": CustomDtype.INT2,
|
|
94
|
-
}
|
|
95
|
-
target_dtype = mapping[self.quantization_config.weights]
|
|
96
|
-
return target_dtype
|
|
98
|
+
return super().param_element_size(model, param_name, param)
|
|
97
99
|
|
|
98
|
-
def _process_model_before_weight_loading(
|
|
99
|
-
self, model: "PreTrainedModel", keep_in_fp32_modules: list[str] | None = None, **kwargs
|
|
100
|
-
):
|
|
100
|
+
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
|
101
101
|
from ..integrations import replace_with_quanto_layers
|
|
102
102
|
|
|
103
103
|
self.modules_to_not_convert = self.get_modules_to_not_convert(
|
|
104
|
-
model, self.quantization_config.modules_to_not_convert,
|
|
104
|
+
model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
|
|
105
105
|
)
|
|
106
106
|
|
|
107
107
|
model = replace_with_quanto_layers(
|
|
@@ -51,24 +51,19 @@ class SpQRHfQuantizer(HfQuantizer):
|
|
|
51
51
|
raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant[gpu]`")
|
|
52
52
|
|
|
53
53
|
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
|
|
54
|
-
if dtype
|
|
55
|
-
dtype = torch.float16
|
|
56
|
-
logger.info("Assuming SpQR inference on GPU and loading the model in `torch.float16`.")
|
|
57
|
-
elif dtype != torch.float16:
|
|
54
|
+
if dtype != torch.float16:
|
|
58
55
|
raise ValueError(
|
|
59
|
-
"You cannot use any type other than torch.float16 for SpQR. Please
|
|
60
|
-
"torch.float16 explicitly."
|
|
56
|
+
"You cannot use any type other than torch.float16 for SpQR. Please set it totorch.float16 explicitly."
|
|
61
57
|
)
|
|
62
58
|
return dtype
|
|
63
59
|
|
|
64
60
|
def _process_model_before_weight_loading(
|
|
65
61
|
self,
|
|
66
62
|
model: "PreTrainedModel",
|
|
67
|
-
keep_in_fp32_modules: list[str] | None = None,
|
|
68
63
|
**kwargs,
|
|
69
64
|
):
|
|
70
65
|
self.modules_to_not_convert = self.get_modules_to_not_convert(
|
|
71
|
-
model, self.quantization_config.modules_to_not_convert,
|
|
66
|
+
model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
|
|
72
67
|
)
|
|
73
68
|
replace_with_spqr_linear(
|
|
74
69
|
model,
|