transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/quantizers/quantizer_hqq.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections import defaultdict
 from typing import TYPE_CHECKING

 from ..integrations import prepare_for_hqq_linear
@@ -49,10 +48,7 @@ class HqqHfQuantizer(HfQuantizer):
     nn.Linear modules are first tagged with quant_config in _process_model_before_weight_loading().
     """

-    use_keep_in_fp32_modules = False
-    requires_parameters_quantization = True
     requires_calibration = False
-    required_packages = ["hqq"]

     def __init__(self, quantization_config, **kwargs):
         if not is_hqq_available():
@@ -83,73 +79,67 @@ class HqqHfQuantizer(HfQuantizer):
         else:
             self.using_multi_gpu = len(set(device_map.values())) > 1

-    [old lines 86-146 removed: previous update_expected_keys implementation; content not rendered in the source diff view]
-        else:
-            new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
-        if _module + ".bias" in loaded_keys:
-            new_keys.add(_module + ".bias")
-
-        return list(new_keys)
+    # TODO: to remove
+    # Kept here in case we see some interest in adding support for it
+    # # Adds missing keys for HQQLinear modules that are loaded but the model with initialized with torch.nn.Linear
+    # def update_expected_keys(
+    #     self, model: "PreTrainedModel", expected_keys: list[str], loaded_keys: list[str]
+    # ) -> list[str]:
+    #     if not self.pre_quantized:
+    #         return expected_keys
+
+    #     # Collects all quantizable (linear) layers
+    #     def _find_hqq_quantizable_layers(model, layers):
+    #         for name, module in model.named_children():
+    #             if isinstance(module, (torch.nn.Linear)):
+    #                 layers.add(module.name)
+    #             _find_hqq_quantizable_layers(module, layers)
+
+    #     new_keys = set(expected_keys)
+
+    #     # Name modules
+    #     for name, module in model.named_modules():
+    #         module.name = name
+
+    #     # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
+    #     _valid_modules = set()
+    #     _find_hqq_quantizable_layers(model, _valid_modules)
+
+    #     # Remove skipped modules
+    #     _skipped_modules = set()
+    #     for _module in _valid_modules:
+    #         for _skip_module in model.config.quantization_config["skip_modules"]:
+    #             if _skip_module in _module:
+    #                 _skipped_modules.add(_module)
+    #     _valid_modules -= _skipped_modules
+
+    #     # Append new expected layers based on _ref_keys
+    #     _ref_keys = HQQLinear(
+    #         linear_layer=None,
+    #         quant_config=None,
+    #         compute_dtype=torch.float16,
+    #         device="cpu",
+    #         del_orig=False,
+    #     ).state_dict_keys() - {"bias"}
+
+    #     # Clean-up
+    #     _rm_keys = set()
+    #     for key in new_keys:
+    #         if any(_module in key for _module in _valid_modules):
+    #             _rm_keys.add(key)
+    #     new_keys -= _rm_keys
+    #     # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
+
+    #     # Re-populate Linear/HQQLinear
+    #     for _module in _valid_modules:
+    #         if _module + ".weight" in loaded_keys:
+    #             new_keys.add(_module + ".weight")
+    #         else:
+    #             new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
+    #         if _module + ".bias" in loaded_keys:
+    #             new_keys.add(_module + ".bias")
+
+    #     return list(new_keys)

     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         module, _ = get_module_from_name(model, param_name)
@@ -157,87 +147,88 @@ class HqqHfQuantizer(HfQuantizer):
         # `create_quantized_param`, even when `self.is_quantized == True`
         return isinstance(module, torch.nn.Linear)

-    [old lines 160-240 removed: previous create_quantized_param implementation; content not rendered in the source diff view]
+    # TODO: to remove
+    # def create_quantized_param(
+    #     self,
+    #     model: "PreTrainedModel",
+    #     param_value: "torch.Tensor",
+    #     param_name: str,
+    #     target_device: "torch.device",
+    #     **kwargs,
+    # ):
+    #     module, tensor_name = get_module_from_name(model, param_name)
+    #     module_name = param_name.rsplit(".", 1)[0]
+    #     parent_module, node = get_module_from_name(model, module_name)
+
+    #     quant_config = model.config.quantization_config["quant_config"]
+    #     skip_modules = model.config.quantization_config["skip_modules"]
+
+    #     # In this case we do not quantize this layer (it's explicitly skipped) -> simply load param
+    #     if any(skip_module in module.name for skip_module in skip_modules):
+    #         module.load_state_dict(
+    #             {tensor_name: param_value.to(device=target_device, dtype=self.dtype)}, strict=False, assign=True
+    #         )
+    #         return
+
+    #     # We need this hack as the model is not pre-prepared as an empty skeleton on meta device
+    #     if self.pre_quantized:
+    #         # Save them for later
+    #         if not hasattr(self, "hqq_params"):
+    #             self.hqq_params = defaultdict(dict)
+    #         self.hqq_params[module_name].update({tensor_name: param_value})
+    #         hqq_params = self.hqq_params[module_name]
+
+    #         # If they are all present and saved, make it a HQQLinear layer! (we cannot do it param after param because
+    #         # hqq does not support it...)
+    #         if all(k in hqq_params for k in self.hqq_keys) and ("bias" in hqq_params or module.bias is None):
+    #             hqq_layer = HQQLinear(
+    #                 linear_layer=None,
+    #                 quant_config=None,
+    #                 compute_dtype=self.dtype,
+    #                 device=target_device,
+    #                 del_orig=False,
+    #             )
+    #             hqq_layer.load_state_dict(hqq_params)
+
+    #             if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+    #                 hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+    #             if self.using_multi_gpu:
+    #                 hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+    #             setattr(parent_module, node, hqq_layer)
+    #             del self.hqq_params[module_name], module
+    #         return
+
+    #     # Load param in the module (without caring about device or dtype, it will be changed later)
+    #     module.load_state_dict({tensor_name: param_value}, strict=False, assign=True)
+
+    #     # If both the weight and bias have already been loaded, time to quantize!
+    #     module_is_ready = module.weight.device.type != "meta" and (
+    #         module.bias is None or module.bias.device.type != "meta"
+    #     )
+
+    #     if module_is_ready:
+    #         module_tag = ".".join(module.name.split(".")[-2:])
+    #         if "weight_quant_params" in quant_config:
+    #             module_quant_config = quant_config
+    #         elif module_tag in quant_config:
+    #             module_quant_config = quant_config[module_tag]
+
+    #         hqq_layer = HQQLinear(
+    #             module,
+    #             quant_config=module_quant_config,
+    #             compute_dtype=self.dtype,
+    #             device=target_device,
+    #             del_orig=True,
+    #         )
+
+    #         if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+    #             hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+
+    #         if self.using_multi_gpu:
+    #             hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+    #         setattr(parent_module, node, hqq_layer)

     def _patch_layer_for_multigpu(self, hqq_layer):
         def forward_with_device(self, x):
@@ -263,7 +254,7 @@ class HqqHfQuantizer(HfQuantizer):
         model.is_hqq_serializable = self.is_serializable()
         return model

-    def is_serializable(self
+    def is_serializable(self):
         return True

     @property
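The hunks above only touch the HQQ quantizer's internals (dropped class attributes, the per-parameter path preserved as comments, and the `is_serializable()` signature); the public loading entry point is not part of this diff. For orientation, a minimal illustrative sketch, assuming the documented `HqqConfig`/`from_pretrained` usage; the checkpoint id and quantization settings below are placeholders:

```python
# Illustrative sketch, not taken from this diff. Assumes the documented HqqConfig API;
# the checkpoint id and quantization settings are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

quant_config = HqqConfig(nbits=4, group_size=64)  # hypothetical 4-bit settings
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder checkpoint
    device_map="auto",
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
```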
transformers/quantizers/quantizer_mxfp4.py

@@ -43,14 +43,10 @@ class Mxfp4HfQuantizer(HfQuantizer):
    FP4 quantization using fbgemm kernels
    """

-    requires_parameters_quantization = True
     requires_calibration = False

-    required_packages = ["accelerate"]
-
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.quantization_config = quantization_config
         self.triton_kernels_hub = None

     def _lazy_import_kernels(self):
@@ -74,7 +70,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
         if self.quantization_config.dequantize:
             return

-        if not
+        if not torch.cuda.is_available() and not torch.xpu.is_available():
             if self.pre_quantized:
                 logger.warning_once(
                     "Using MXFP4 quantized models requires a GPU, we will default to dequantizing the model to bf16"
@@ -131,12 +127,8 @@ class Mxfp4HfQuantizer(HfQuantizer):
                 "You have loaded an FP4 model on CPU and have a CUDA/XPU device available, make sure to set "
                 "your model on a GPU/XPU device in order to run your model. To remove this warning, pass device_map = 'cuda' or device_map = 'xpu'. "
             )
-        elif device_map
-            if (
-                not self.pre_quantized
-                and isinstance(device_map, dict)
-                and ("cpu" in device_map.values() or "disk" in device_map.values())
-            ):
+        elif isinstance(device_map, dict):
+            if not self.pre_quantized and ("cpu" in device_map.values() or "disk" in device_map.values()):
                 raise ValueError(
                     "You are attempting to load an FP4 model with a device_map that contains a CPU or disk device."
                     "This is not supported when the model is quantized on the fly. "
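The hunk above collapses the multi-clause guard into a plain `isinstance(device_map, dict)` check. A standalone restatement of that guard, using a hypothetical helper name for illustration only, not library code:

```python
# Hypothetical helper restating the device_map guard from the hunk above.
def check_mxfp4_device_map(device_map, pre_quantized: bool) -> None:
    if isinstance(device_map, dict):
        # On-the-fly quantization with CPU or disk offload is rejected.
        if not pre_quantized and ("cpu" in device_map.values() or "disk" in device_map.values()):
            raise ValueError(
                "You are attempting to load an FP4 model with a device_map that contains a CPU or disk device. "
                "This is not supported when the model is quantized on the fly."
            )


try:
    check_mxfp4_device_map({"model.layers.0": 0, "lm_head": "cpu"}, pre_quantized=False)
except ValueError as err:
    print(err)  # CPU offload while quantizing on the fly is not allowed
```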
@@ -157,159 +149,30 @@ class Mxfp4HfQuantizer(HfQuantizer):

     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         from ..integrations import Mxfp4GptOssExperts
-        from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts

-
-
-        # if we are dequantizing, the model doesn't have scales, and blocks only params like gate_up_proj and down_proj so we need to handle this case differently
-        if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name):
-            module, tensor_name = get_module_from_name(model, param_name[: -len("_blocks")])
-        else:
-            module, tensor_name = get_module_from_name(model, param_name)
-        if isinstance(module, Mxfp4GptOssExperts) or (
-            isinstance(module, GptOssExperts) and self.quantization_config.dequantize
-        ):
+        module, tensor_name = get_module_from_name(model, param_name)
+        if isinstance(module, Mxfp4GptOssExperts):
             if tensor_name in ["down_proj_bias", "gate_up_proj_bias"]:
                 return False
             return True
         return False

-    def create_quantized_param(
-        self,
-        model: "PreTrainedModel",
-        param_value: "torch.Tensor",
-        param_name: str,
-        target_device: "torch.device",
-        **kwargs,
-    ):
-        from ..integrations import (
-            Mxfp4GptOssExperts,
-            dequantize,
-            load_and_swizzle_mxfp4,
-            quantize_to_mxfp4,
-            swizzle_mxfp4,
-        )
-        from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
-
-        if not self.pre_quantized:
-            triton_kernels_hub = self._lazy_import_kernels()
-            module, _ = get_module_from_name(model, param_name)
-            with torch.device(target_device):
-                if isinstance(module, Mxfp4GptOssExperts):
-                    triton_weight_tensor, weight_scale = quantize_to_mxfp4(param_value, triton_kernels_hub)
-                    PrecisionConfig, FlexCtx, InFlexData = (
-                        triton_kernels_hub.matmul_ogs.PrecisionConfig,
-                        triton_kernels_hub.matmul_ogs.FlexCtx,
-                        triton_kernels_hub.matmul_ogs.InFlexData,
-                    )
-                    triton_weight_tensor, weight_scale = swizzle_mxfp4(
-                        triton_weight_tensor, weight_scale, triton_kernels_hub
-                    )
-
-                    proj = "gate_up_proj" if "gate_up_proj" in param_name else "down_proj"
-                    setattr(module, proj, triton_weight_tensor)
-                    setattr(
-                        module,
-                        f"{proj}_precision_config",
-                        PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
-                    )
-
-                    delattr(module, f"{proj}_blocks")
-                    delattr(module, f"{proj}_scales")
-
-        # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales
-        else:
-            # This is when loading a quantized model (blocks and scales exist)
-            empty_param = kwargs.get("empty_param")
-            casting_dtype = kwargs.get("casting_dtype")
-            to_contiguous = kwargs.get("to_contiguous")
-            rank = kwargs.get("rank")
-            device_mesh = kwargs.get("device_mesh")
-            if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize:
-                # blocks and scales have the same length that's why this works for both
-                module, _ = get_module_from_name(model, param_name[: -len("_blocks")])
-            else:
-                module, _ = get_module_from_name(model, param_name)
-
-            shard_kwargs = {
-                "empty_param": empty_param,
-                "casting_dtype": casting_dtype,
-                "to_contiguous": to_contiguous,
-                "rank": rank,
-                "device_mesh": device_mesh,
-                "model": model,
-            }
-
-            if isinstance(module, Mxfp4GptOssExperts) or (
-                isinstance(module, GptOssExperts) and self.quantization_config.dequantize
-            ):
-                if self.quantization_config.dequantize:
-                    # dq_param_name is the name of the parameter without the blocks or scales suffix, it's used in this case since we don't switch linears
-                    # so we only have the original param name
-                    dq_param_name = param_name[: -len("_blocks")]
-                    dequantize(module, param_name, param_value, target_device, dq_param_name, **shard_kwargs)
-                else:
-                    load_and_swizzle_mxfp4(
-                        module,
-                        param_name,
-                        param_value,
-                        target_device,
-                        self._lazy_import_kernels(),
-                        **shard_kwargs,
-                    )
-
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        # we are not really dequantizing, we are just removing everything related to quantization here
-        if self.quantization_config.dequantize:
-            self.remove_quantization_config(model)
         # clean cache due to triton ops
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         elif torch.xpu.is_available():
             torch.xpu.empty_cache()

-    def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str], checkpoint_keys: list[str]):
-        # Replace expected_keys for experts' gate_up_proj and down_proj with their _blocks and _scales variants
-        new_expected_keys = []
-        for key in expected_keys:
-            if key.endswith(".mlp.experts.gate_up_proj"):
-                base = key[: -len("gate_up_proj")]
-                new_expected_keys.append(base + "gate_up_proj_blocks")
-                new_expected_keys.append(base + "gate_up_proj_scales")
-            elif key.endswith(".mlp.experts.down_proj"):
-                base = key[: -len("down_proj")]
-                new_expected_keys.append(base + "down_proj_blocks")
-                new_expected_keys.append(base + "down_proj_scales")
-            elif not self.pre_quantized:
-                # in this case, we are quantizing the model so we need to update the keys as we changed the layers
-                if key.endswith(".mlp.experts.down_proj_blocks"):
-                    base = key[: -len("down_proj_blocks")]
-                    new_expected_keys.append(base + "down_proj")
-                elif key.endswith(".mlp.experts.gate_up_proj_blocks"):
-                    base = key[: -len("gate_up_proj_blocks")]
-                    new_expected_keys.append(base + "gate_up_proj")
-                elif key.endswith("scales"):
-                    # we remove it the scales as the checkpoint don't contain them
-                    continue
-                else:
-                    new_expected_keys.append(key)
-            else:
-                new_expected_keys.append(key)
-        return new_expected_keys
-
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
         keep_in_fp32_modules: list[str] | None = None,
+        use_kernels: bool = False,
         **kwargs,
     ):
         from ..integrations import replace_with_mxfp4_linear

-        self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
-        )
-
-        use_kernels = kwargs.get("use_kernels", False)
         # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling
         if use_kernels:
             logger.warning_once(
@@ -318,30 +181,13 @@ class Mxfp4HfQuantizer(HfQuantizer):
             )
             self.quantization_config.dequantize = True

-
-
-            model,
-            modules_to_not_convert=self.modules_to_not_convert,
-            quantization_config=self.quantization_config,
-            config=config,
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )

-        model
-
-
-        from ..integrations import Mxfp4GptOssExperts
-
-        not_missing_keys = []
-        for name, module in model.named_modules():
-            if isinstance(module, Mxfp4GptOssExperts):
-                for missing in missing_keys:
-                    if (
-                        (name in missing or name in f"{prefix}.{missing}")
-                        and not missing.endswith(".weight")
-                        and not missing.endswith(".bias")
-                    ):
-                        not_missing_keys.append(missing)
-        return [k for k in missing_keys if k not in not_missing_keys]
+        model = replace_with_mxfp4_linear(
+            model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
+        )

     def update_tp_plan(self, config):
         if "GptOssConfig" in config.__class__.__name__:
@@ -382,7 +228,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
             return param_name.replace("down_proj", "down_proj_blocks")
         return param_name

-    def get_state_dict_and_metadata(self, model
+    def get_state_dict_and_metadata(self, model):
         from ..integrations import Mxfp4GptOssExperts

         state_dict = model.state_dict()
@@ -421,7 +267,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
         metadata = {}
         return state_dict, metadata

-    def is_serializable(self
+    def is_serializable(self):
         return True

     @property