transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0

transformers/quantizers/quantizer_quanto.py

@@ -40,23 +40,10 @@ class QuantoHfQuantizer(HfQuantizer):
     Quantizer for the quanto library
     """
 
-    required_packages = ["quanto", "accelerate"]
-    requires_parameters_quantization = True
     requires_calibration = False
 
     def __init__(self, quantization_config: QuantoConfig, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.post_init()
-
-    def post_init(self):
-        r"""
-        Safety checker
-        """
-        if self.quantization_config.activations is not None and not self.pre_quantized:
-            raise ValueError(
-                "We don't support quantizing the activations with transformers library."
-                "Use quanto library for more complex use cases such as activations quantization, calibration and quantization aware training."
-            )
 
     def validate_environment(self, *args, **kwargs):
         if not is_optimum_quanto_available():
@@ -67,42 +54,22 @@ class QuantoHfQuantizer(HfQuantizer):
             raise ImportError(
                 "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
            )
-
-    def update_device_map(self, device_map):
-        if device_map is None:
-            device_map = {"": "cpu"}
-            logger.info(
-                "The device_map was not initialized. "
-                "Setting device_map to {'':'cpu'}. "
-                "If you want to use the model for inference, please set device_map ='auto'"
+        device_map = kwargs.get("device_map")
+        if isinstance(device_map, dict):
+            if len(device_map) > 1 and "cpu" in device_map.values() or "disk" in device_map.values():
+                raise ValueError(
+                    "You are attempting to load an model with a device_map that contains a CPU or disk device."
+                    "This is not supported with quanto when the model is quantized on the fly. "
+                    "Please remove the CPU or disk device from the device_map."
+                )
+        if self.quantization_config.activations is not None:
+            raise ValueError(
+                "We don't support quantizing the activations with transformers library."
+                "Use quanto library for more complex use cases such as activations quantization, calibration and quantization aware training."
             )
-        return device_map
-
-    def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
-        if dtype is None:
-            logger.info("You did not specify `dtype` in `from_pretrained`. Setting it to `torch.float32`.")
-            dtype = torch.float32
-        return dtype
-
-    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
-        if is_optimum_quanto_available():
-            from optimum.quanto import QModuleMixin
-
-            not_missing_keys = []
-            for name, module in model.named_modules():
-                if isinstance(module, QModuleMixin):
-                    for missing in missing_keys:
-                        if (
-                            (name in missing or name in f"{prefix}.{missing}")
-                            and not missing.endswith(".weight")
-                            and not missing.endswith(".bias")
-                        ):
-                            not_missing_keys.append(missing)
-            return [k for k in missing_keys if k not in not_missing_keys]
 
     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
-
-        from optimum.quanto import QModuleMixin
+        from optimum.quanto import QModuleMixin
 
         module, tensor_name = get_module_from_name(model, param_name)
         # We only quantize the weights and the bias is not quantized.
@@ -116,21 +83,6 @@ class QuantoHfQuantizer(HfQuantizer):
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
         return max_memory
 
-    def create_quantized_param(
-        self,
-        model: "PreTrainedModel",
-        param_value: "torch.Tensor",
-        param_name: str,
-        target_device: "torch.device",
-        **kwargs,
-    ):
-        from ..modeling_utils import _load_parameter_into_model
-
-        _load_parameter_into_model(model, param_name, param_value.to(target_device))
-        module, _ = get_module_from_name(model, param_name)
-        module.freeze()
-        module.weight.requires_grad = False
-
     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
         from accelerate.utils import CustomDtype
 
@@ -152,14 +104,18 @@ class QuantoHfQuantizer(HfQuantizer):
             model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
 
-        model, _ = replace_with_quanto_layers(
+        model = replace_with_quanto_layers(
             model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
         )
-        model.config.quantization_config = self.quantization_config
 
     @property
     def is_trainable(self) -> bool:
         return True
 
-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self):
         return False
+
+    def get_quantize_ops(self):
+        from ..integrations.quanto import QuantoQuantize
+
+        return QuantoQuantize(self)
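
A note on the new `validate_environment` check above: because `and` binds tighter than `or` in Python, the guard reads as `(len(device_map) > 1 and "cpu" in values) or "disk" in values`, so any disk placement is rejected outright while CPU placements are only rejected in multi-entry maps. A minimal, illustrative sketch of that same boolean expression (the device_map values are invented for the example, not taken from the diff):

# Minimal sketch of the rc1 quanto device_map guard; example maps are hypothetical.
def rejected(device_map: dict) -> bool:
    return len(device_map) > 1 and "cpu" in device_map.values() or "disk" in device_map.values()

print(rejected({"": 0}))                                 # False: single accelerator entry
print(rejected({"model.layers": 0, "lm_head": "cpu"}))   # True: CPU offload in a multi-entry map
print(rejected({"model.layers": 0, "lm_head": "disk"}))  # True: any disk offload
print(rejected({"": "cpu"}))                             # False: single "cpu" entry slips past this check
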

transformers/quantizers/quantizer_quark.py

@@ -45,12 +45,6 @@ class QuarkHfQuantizer(HfQuantizer):
     """
 
     requires_calibration = True  # On-the-fly quantization with quark is not supported for now.
-    required_packages = ["quark"]
-
-    # Checkpoints are expected to be already quantized when loading a quark model. However, as some keys from
-    # the checkpoint might mismatch the model parameters keys, we use the `create_quantized_param` method
-    # to load the checkpoints, remapping the keys.
-    requires_parameters_quantization = True
 
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
@@ -78,19 +72,44 @@ class QuarkHfQuantizer(HfQuantizer):
     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         return True
 
-    def create_quantized_param(self, model, param, param_name, param_device, **kwargs):
-        from ..modeling_utils import _load_parameter_into_model
-
-        postfix = param_name.split(".")[-1]
-
-        if postfix in CHECKPOINT_KEYS:
-            param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix])
-
-        _load_parameter_into_model(model, param_name, param.to(param_device))
-
-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self):
         return False
 
     @property
     def is_trainable(self):
         return False
+
+    def get_weight_conversions(self):
+        from ..core_model_loading import WeightConverter
+        from ..integrations.quark import QuarkDeserialize
+        # In Quark, quantization is managed through a QParamsLinear module, which holds
+        # separate quantizers for the weights, inputs, and biases (e.g. weight_quantizer
+        # input_quantizer, bias_quantizer, etc.).
+        #
+        # When you call `module.state_dict()`, Quark automatically renames the quantizer
+        # parameters — for example, `input_quantizer.scale` becomes `input_scale` — and
+        # saves them directly at the parent module level.
+        #
+        # This means we cannot simply rename keys like `weight_scale` back to
+        # `weight_quantizer.scale` when loading the state_dict.
+        # Otherwise, the `missing_keys` list would still expect keys such as
+        # `weight_scale`, `bias_scale`, etc.
+        #
+        # To fix this, we keep the expected state_dict keys (like `weight_scale`,
+        # `bias_scale`, etc.) unchanged, and during the conversion step, we explicitly
+        # assign their values into the corresponding quantizer attributes
+        # (`weight_quantizer.scale`, `input_quantizer.scale`, and so on).
+
+        # You can notice here that in target_patterns we use the same key as the source_patterns,
+        # this is because we just want to collect the tensors, and we will rename them later in the convert function.
+        # We cannot rename directly or else the missing_keys list will not be able to find the tensors.
+        converters = []
+        for key in CHECKPOINT_KEYS.keys():
+            converters.append(
+                WeightConverter(
+                    source_patterns=[key],
+                    target_patterns=key,
+                    operations=[QuarkDeserialize(self)],
+                )
+            )
+        return converters
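
To make the `get_weight_conversions` comment block above concrete, here is a small stand-alone sketch of the renaming it describes. The `CHECKPOINT_KEYS` entries are hypothetical (the real table lives elsewhere in `quantizer_quark.py` and is not shown in this hunk); the suffix substitution mirrors the removed `create_quantized_param`, which `QuarkDeserialize` now defers to convert time so that `missing_keys` keeps seeing the plain state_dict names.

# Hypothetical CHECKPOINT_KEYS entries, for illustration only.
CHECKPOINT_KEYS = {
    "weight_scale": "weight_quantizer.scale",  # state_dict key -> quantizer attribute path
    "input_scale": "input_quantizer.scale",
}

def remap(param_name: str) -> str:
    # Same suffix substitution the removed create_quantized_param performed eagerly.
    postfix = param_name.split(".")[-1]
    return param_name.replace(postfix, CHECKPOINT_KEYS.get(postfix, postfix))

print(remap("model.layers.0.mlp.down_proj.weight_scale"))
# model.layers.0.mlp.down_proj.weight_quantizer.scale
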

transformers/quantizers/quantizer_spqr.py

@@ -39,7 +39,6 @@ class SpQRHfQuantizer(HfQuantizer):
 
     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.quantization_config = quantization_config
 
     def validate_environment(self, *args, **kwargs):
         if not torch.cuda.is_available():
@@ -71,17 +70,15 @@ class SpQRHfQuantizer(HfQuantizer):
         self.modules_to_not_convert = self.get_modules_to_not_convert(
             model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
-
         replace_with_spqr_linear(
             model,
             quantization_config=self.quantization_config,
             modules_to_not_convert=self.modules_to_not_convert,
         )
-        model.config.quantization_config = self.quantization_config
 
     @property
     def is_trainable(self):
         return False
 
-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self):
         return True

transformers/quantizers/quantizer_torchao.py

@@ -13,8 +13,6 @@
 # limitations under the License.
 import importlib
 import re
-import types
-from collections import defaultdict
 from typing import TYPE_CHECKING
 
 from packaging import version
@@ -37,17 +35,12 @@ if is_torch_available():
 
 if is_torch_available():
     import torch
-    import torch.nn as nn
 
 if is_torchao_available():
-    import torchao
-
-    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
         from torchao.prototype.safetensors.safetensors_support import (
             flatten_tensor_state_dict,
-            unflatten_tensor_state_dict,
         )
-        from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
 
 
 logger = logging.get_logger(__name__)
@@ -88,11 +81,6 @@ def _linear_extra_repr(self):
 
 
 if is_torchao_available():
-    SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
-        torchao.quantization.Float8WeightOnlyConfig,
-        torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
-    ]
-
     TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
 
 
@@ -101,9 +89,7 @@ class TorchAoHfQuantizer(HfQuantizer):
     Quantizer for torchao: https://github.com/pytorch/ao/
     """
 
-    requires_parameters_quantization = True
     requires_calibration = False
-    required_packages = ["torchao"]
 
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
@@ -166,20 +152,16 @@ class TorchAoHfQuantizer(HfQuantizer):
             dtype = torch.float32
         return dtype
 
-    def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
+    def get_state_dict_and_metadata(self, model):
         """
-        If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
-        the safetensors format.
+        We flatten the state dict of tensor subclasses so that it is compatible with the safetensors format.
         """
-        if safe_serialization:
-            if TORCHAO_VERSION >= version.parse("0.14.0"):
-                return flatten_tensor_state_dict(model.state_dict())
-            else:
-                raise RuntimeError(
-                    f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
-                )
+        if TORCHAO_VERSION >= version.parse("0.15.0"):
+            return flatten_tensor_state_dict(model.state_dict()), {}
         else:
-            return super().get_state_dict_and_metadata(model, safe_serialization)
+            raise RuntimeError(
+                f"In order to use safetensors with torchao, please use torchao version >= 0.15.0. Current version: {TORCHAO_VERSION}"
+            )
 
     def adjust_target_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
         from accelerate.utils import CustomDtype
@@ -237,9 +219,6 @@ class TorchAoHfQuantizer(HfQuantizer):
         ]
         return
 
-    def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
-        return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)]
-
     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         if self.pre_quantized:
             return False
@@ -249,8 +228,6 @@ class TorchAoHfQuantizer(HfQuantizer):
         # check if the param_name is not in self.modules_to_not_convert
         if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
             return False
-        elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
-            return True
 
         # we only quantize the weight of nn.Linear and nn.Embedding
         module, tensor_name = get_module_from_name(model, param_name)
@@ -276,148 +253,6 @@ class TorchAoHfQuantizer(HfQuantizer):
 
         return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight"
 
-    def create_quantized_param(
-        self,
-        model: "PreTrainedModel",
-        param_value: "torch.Tensor",
-        param_name: str,
-        target_device: "torch.device",
-        **kwargs,
-    ):
-        """
-        Each nn.Linear layer that needs to be quantized is processed here.
-        First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module.
-        """
-        from torchao.quantization import quantize_
-
-        full_name = param_name
-        # Those are the pre quantized weights
-        if ":" in param_name:
-            param_name = param_name.rsplit(":", 1)[0]
-        module, tensor_name = get_module_from_name(model, param_name)
-
-        if self.pre_quantized:
-            # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
-            # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either
-            is_unsafe_serialization = ":" not in full_name
-            if tensor_name == "bias" or is_unsafe_serialization:
-                module._parameters[tensor_name] = torch.nn.Parameter(
-                    param_value.to(target_device), requires_grad=param_value.requires_grad
-                )
-                return
-            # Sanity check for the new serialization format
-            elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
-                raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
-
-            # Save the states for later quantization when they are all gathered
-            if not hasattr(self, "ao_params"):
-                self.ao_params = defaultdict(dict)
-            self.ao_params[param_name].update({full_name: param_value})
-
-            # We are ready for quantization in this case (we retrieved all the needed keys)
-            if len(self.ao_params[param_name]) == len(self.weight_ao_keys):
-                new_param = unflatten_tensor_state_dict(self.ao_params[param_name], self.metadata)[param_name]
-                # Set it
-                module._parameters[tensor_name] = torch.nn.Parameter(
-                    new_param.to(target_device), requires_grad=new_param.requires_grad
-                )
-
-                # Free memory
-                del self.ao_params[param_name]
-
-                # Add repr to the module
-                if isinstance(module, nn.Linear):
-                    module.extra_repr = types.MethodType(_linear_extra_repr, module)
-        else:
-            module._parameters[tensor_name] = torch.nn.Parameter(
-                param_value, requires_grad=param_value.requires_grad
-            ).to(target_device)
-            # if we are quantizing tied parameters, to avoid tying the quantized weights
-            # the correct order to do it is
-            # 1. load the weight to model
-            # 2. run tie_weights to populate the weights
-            # 3. quantize
-            input_embed = model.get_input_embeddings()
-            if self.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
-                model.tie_weights()
-                setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
-
-            # handle FqnToConfig, introduced in torchao 0.15.0+
-            if self.quantization_config._get_ao_version() >= version.Version("0.15.0"):
-                from torchao.quantization import FqnToConfig
-
-                config = self.quantization_config.get_apply_tensor_subclass()
-                if isinstance(config, FqnToConfig):
-                    module_fqn, top_level_param_name = param_name.rsplit(".", 1)
-                    c = None
-                    if param_name in config.fqn_to_config:
-                        assert not module_fqn.startswith("re:"), (
-                            "param fqn should not start with`re:`, which is used for specifying regex"
-                        )
-                        c = config.module_fqn_to_config[param_name]
-                    elif module_fqn in config.fqn_to_config:
-                        assert not module_fqn.startswith("re:"), (
-                            "module fqn should not start with`re:`, which is used for specifying regex"
-                        )
-                        c = config.module_fqn_to_config[module_fqn]
-                    # regex match module and param
-                    else:
-                        for maybe_module_fqn_pattern in config.fqn_to_config:
-                            # if key doesn't start with re, it is an exact fqn key, so we don't regex match
-                            if not maybe_module_fqn_pattern.startswith("re:"):
-                                continue
-                            # see if param matches first
-                            elif re.fullmatch(maybe_module_fqn_pattern[3:], param_name):
-                                c = config.module_fqn_to_config[maybe_module_fqn_pattern]
-                                break
-                            elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
-                                # we'll apply the config for first fully matched pattern
-                                c = config.module_fqn_to_config[maybe_module_fqn_pattern]
-                                break
-                        else:
-                            c = config.module_fqn_to_config.get("_default", None)
-
-                    if c is not None:
-                        if top_level_param_name == "weight":
-                            # we can apply the module config directly
-                            quantize_(module, c, (lambda x, fqn: True))
-                        else:
-                            # need to apply to custom param name
-                            custom_param_fqn_config = FqnToConfig({top_level_param_name: c})
-                            quantize_(module, custom_param_fqn_config, filter_fn=None)
-                    return
-
-            # handle ModuleFqnToConfig, introduced in torchao 0.12.0+
-            # TODO deprecate this when we deprecate ModuleFqnToConfig
-            elif self.quantization_config._get_ao_version() >= version.Version("0.12.0"):
-                from torchao.quantization import ModuleFqnToConfig
-
-                config = self.quantization_config.get_apply_tensor_subclass()
-                if isinstance(config, ModuleFqnToConfig):
-                    module_fqn, _ = param_name.rsplit(".", 1)
-                    c = None
-                    if module_fqn in config.module_fqn_to_config:
-                        assert not module_fqn.startswith("re:"), (
-                            "module fqn should not start with`re:`, which is used for specifying regex"
-                        )
-                        c = config.module_fqn_to_config[module_fqn]
-                    else:
-                        for maybe_module_fqn_pattern in config.module_fqn_to_config:
-                            if not maybe_module_fqn_pattern.startswith("re:"):
-                                continue
-                            elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
-                                # we'll apply the config for first fully matched pattern
-                                c = config.module_fqn_to_config[maybe_module_fqn_pattern]
-                                break
-                        else:
-                            c = config.module_fqn_to_config.get("_default", None)
-                    if c is not None:
-                        # filter_fn: not filtering out any modules
-                        quantize_(module, c, filter_fn=lambda x, fqn: True)
-                    return
-
-            quantize_(module, self.quantization_config.get_apply_tensor_subclass())
-
     def preprocess_model(self, model: "PreTrainedModel", config, dtype=None, checkpoint_files=None, **kwargs):
         """
         Setting model attributes and/or converting model before weights loading. At this point
@@ -450,30 +285,13 @@ class TorchAoHfQuantizer(HfQuantizer):
             return model
         return
 
-    def is_serializable(self, safe_serialization=None) -> bool:
-        if safe_serialization:
-            _is_torchao_serializable = type(
-                self.quantization_config.quant_type
-            ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
-            if not _is_torchao_serializable:
-                logger.warning(
-                    f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \
-                    and torchao version >= 0.14.0, please set `safe_serialization` to False for \
-                    {type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}."
-                )
-            return _is_torchao_serializable
-
-        _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
-            "0.25.0"
-        )
-        if not _is_torchao_serializable:
-            logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
-        if self.offload and self.quantization_config.modules_to_not_convert is None:
+    def is_serializable(self) -> bool:
+        _is_torchao_serializable = TORCHAO_VERSION >= version.parse("0.15.0")
+        if not TORCHAO_VERSION >= version.parse("0.15.0"):
             logger.warning(
-                "
-                "
+                "torchao quantized model only supports serialization for torchao version >= 0.15.0, please upgrade "
+                "your version to save the quantized model"
             )
-            return False
         return _is_torchao_serializable
 
     def get_accelerator_warm_up_factor(self):
@@ -548,15 +366,18 @@ class TorchAoHfQuantizer(HfQuantizer):
         if self.pre_quantized:
             return [
                 WeightConverter(
-                    source_patterns=[
-                        "_weight_qdata",
-                        "_weight_scale",
-                        "_weight_zero_point",
-                        "_weight_act_pre_scale",
-                    ],
+                    # TODO: incr flexibility by generalizing the source patterns to match the format of "_weight_"
+                    # note that the matching logic is greedy, so for ex, if _weight_scale is before _weight_scale_and_zero in this list, it will match _weight_scale always (this is incorrect)
+                    # thus, the order of source_patterns is intentional
+                    source_patterns=[
+                        "_weight_qdata",
+                        "_weight_scale_and_zero",
+                        "_weight_scale",
+                        "_weight_zero_point",
+                        "_weight_act_pre_scale",
+                    ],
                     target_patterns="weight",
                     operations=[TorchAoDeserialize(self)],
                 ),
-                # used for unsafe serialization
             ]
         return []
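
The `source_patterns` comment in the last torchao hunk is the subtle part: matching is greedy and first-hit-wins, so `_weight_scale_and_zero` has to be listed before its prefix `_weight_scale`. A stand-alone illustration of the ordering problem (this is a plain substring loop, not the transformers matching code):

def first_hit(patterns, checkpoint_key):
    # greedy, first-hit-wins substring match
    return next((p for p in patterns if p in checkpoint_key), None)

key = "model.layers.0.self_attn.q_proj._weight_scale_and_zero"
print(first_hit(["_weight_scale", "_weight_scale_and_zero"], key))  # _weight_scale (wrong: the shorter prefix wins)
print(first_hit(["_weight_scale_and_zero", "_weight_scale"], key))  # _weight_scale_and_zero (the order used in rc1)
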

transformers/quantizers/quantizer_vptq.py

@@ -35,11 +35,9 @@ class VptqHfQuantizer(HfQuantizer):
     """
 
     requires_calibration = True
-    required_packages = ["vptq"]
 
     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.quantization_config = quantization_config
 
     def validate_environment(self, *args, **kwargs):
         if not is_accelerate_available():
@@ -48,21 +46,15 @@ class VptqHfQuantizer(HfQuantizer):
         if not is_vptq_available():
             raise ImportError("Using `vptq` quantization requires VPTQ>=0.0.4: `pip install -U vptq`")
 
+        if not torch.cuda.is_available():
+            raise RuntimeError("GPU is required to run VTPQ quantized model.")
+
     def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
         if dtype is None:
-            if torch.cuda.is_available():
-                dtype = torch.float16
-                logger.info(
-                    "Assuming VPTQ inference on GPU and loading the model in `torch.float16`. To overwrite it, set `dtype` manually."
-                )
-            else:
-                import vptq
-
-                device_availability = getattr(vptq, "device_availability", lambda device: False)
-                if device_availability("cpu") is True:
-                    raise RuntimeError("No GPU found. Please wait for the next release of VPTQ to use CPU inference")
-                dtype = torch.float32
-                logger.info("No GPU found. Assuming VPTQ inference on CPU and loading the model in `torch.float32`.")
+            dtype = torch.float16
+            logger.info(
+                "Assuming VPTQ inference on GPU and loading the model in `torch.float16`. To overwrite it, set `dtype` manually."
+            )
         return dtype
 
     def _process_model_before_weight_loading(
@@ -71,26 +63,20 @@ class VptqHfQuantizer(HfQuantizer):
         keep_in_fp32_modules: list[str] | None = None,
         **kwargs,
     ):
-        """
-        we don't have param like modules_to_not_convert to indicate which layers should not be quantized
-        because `quantization_config` include the layers that should be quantized
-        """
         from ..integrations import replace_with_vptq_linear
 
         self.modules_to_not_convert = self.get_modules_to_not_convert(
             model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
        )
-
        replace_with_vptq_linear(
            model,
            quantization_config=self.quantization_config,
            modules_to_not_convert=self.modules_to_not_convert,
        )
-        model.config.quantization_config = self.quantization_config
 
     @property
     def is_trainable(self) -> bool:
         return False
 
-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self):
         return True

transformers/quantizers/quantizers_utils.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 from typing import Any
 
 
@@ -19,3 +20,22 @@ def get_module_from_name(module, tensor_name: str) -> tuple[Any, str]:
     module_name, tensor_name = tensor_name.rsplit(".", 1)
     module = module.get_submodule(module_name)
     return module, tensor_name
+
+
+def should_convert_module(full_name, patterns: list[str] | None = None):
+    if patterns is None:
+        return True
+
+    # We should avoid converting in the following situations:
+    # 1. The pattern appears as a prefix followed by a dot in `full_name`
+    #    (e.g., "model.decoder.layer.11." matches "model.decoder.layer.11.attn.weight").
+    # 2. The pattern matches `full_name` exactly or via regex
+    #    (e.g., "lm_head" matches "lm_head"; "model.decoder.layer.*" matches "model.decoder.layer.11.attn.weight").
+    # 3. `full_name` ends with the pattern
+    #    (e.g., "fc1" matches "model.decoder.layers.23.fc1").
+
+    should_not_convert = any(
+        re.match(f"{key}\\.", full_name) or re.match(f"{key}", full_name) or full_name.endswith(key)
+        for key in patterns
+    )
+    return not should_not_convert