transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/integrations/eetq.py (+84 -79)

@@ -12,11 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .. …
+from ..core_model_loading import ConversionOps
+from ..quantizers.quantizers_utils import should_convert_module
+from ..utils import is_accelerate_available, is_torch_available, logging


-if …
-import …
+if is_torch_available():
+    import torch
 import torch.nn as nn

 if is_accelerate_available():
@@ -25,91 +27,94 @@ if is_accelerate_available():
 logger = logging.get_logger(__name__)


- … (old lines 28-30 truncated in the source view)
-    current_key_name=None,
-    quantization_config=None,
-    has_been_replaced=False,
-    pre_quantized=False,
-):
-    """
-    Private method that wraps the recursion for module replacement.
+class EetqQuantize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer

- … (old lines 39-70 truncated in the source view)
-        has_been_replaced=has_been_replaced,
-        pre_quantized=pre_quantized,
-    )
-    # Remove the last key for recursion
-    current_key_name.pop(-1)
-    return model, has_been_replaced
-
-
-def replace_with_eetq_linear(
-    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
-):
-    """
-    A helper function to replace all `torch.nn.Linear` modules by `eetq.EetqLinear` modules from the `eetq`
-    library. This will enable running your models using high performance int8 weight-only gemm kerner from
-    FasterTransformer and TensorRT-LLM. Make sure `eetq` compiled with the correct CUDA
-    version of your hardware is installed before running this function. EETQ shall be installed via the source
-    'https://github.com/NetEase-FuXi/EETQ'
+    def convert(
+        self, input_dict: dict[str, list[torch.Tensor]], full_layer_name: str | None = None, **kwargs
+    ) -> dict[str, torch.Tensor]:
+        _, value = tuple(input_dict.items())[0]
+        value = value[0]
+
+        value_device = value.device
+        int8_weight = torch.t(value).contiguous().cpu()
+        int8_weight, scales = eetq_kernels_hub.quant_weights(int8_weight, torch.int8, False)
+
+        int8_weight = int8_weight.to(value_device)
+        scales = scales.to(value_device)
+
+        return {full_layer_name: int8_weight, f"{full_layer_name}_scales": scales}
+
+
+class EetqLinearMMFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, scales, bias=None):
+        # The forward pass can use ctx.
+        ctx.save_for_backward(x, weight, scales, bias)
+        output = eetq_kernels_hub.w8_a16_gemm(x, weight, scales)
+        output = output + bias if bias is not None else output
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight, scales, bias = ctx.saved_tensors
+        identity = torch.eye(weight.shape[0]).to(weight.device).to(input.dtype)
+
+        # Dequantize the weight
+        weight = eetq_kernels_hub.w8_a16_gemm(identity, weight, scales)

- … (old lines 89-91 truncated in the source view)
+        if ctx.needs_input_grad[0]:
+            # 2D matrix multiplication, unsqueeze to 3D
+            grad_input = grad_output.squeeze(0).matmul(weight.transpose(0, 1)).unsqueeze(0)
+
+        return grad_input, None, None, None
+
+
+class EetqLinear(nn.Module):
+    def __init__(self, in_features, out_features, dtype=torch.int8, bias=False):
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty((in_features, out_features), dtype=dtype), requires_grad=False)
+        self.weight_scales = nn.Parameter(torch.empty((out_features), dtype=torch.float16))
+        if bias:
+            self.bias = nn.Parameter(torch.empty((out_features), dtype=torch.float16))
+        else:
+            self.bias = None
+
+    def forward(self, input):
+        output = EetqLinearMMFunction.apply(input, self.weight, self.weight_scales, self.bias)
+        return output
+
+
+def replace_with_eetq_linear(model, modules_to_not_convert: list[str] | None = None, pre_quantized=False):
+    """
+    A helper function to replace all `torch.nn.Linear` modules by `EetqLinear` modules.

     Parameters:
         model (`torch.nn.Module`):
             Input model or `torch.nn.Module` as the function is run recursively.
-        modules_to_not_convert (`list[`str`]`, *optional*, defaults to ` …
+        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`):
             Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
             for numerical stability reasons.
-        current_key_name (`list[`str`]`, *optional*):
-            An array to track the current key of the recursion. This is used to check whether the current key (part of
-            it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
-            `disk`).
     """
- … (old lines 104-111 truncated in the source view)
-    )
+    from kernels import get_kernel
+
+    global eetq_kernels_hub
+    eetq_kernels_hub = get_kernel("kernels-community/quantization-eetq")
+
+    has_been_replaced = False
+    # we need this to correctly materialize the weights during quantization
+    module_kwargs = {} if pre_quantized else {"dtype": None}
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+        with init_empty_weights():
+            if isinstance(module, nn.Linear):
+                new_module = EetqLinear(
+                    module.in_features, module.out_features, bias=module.bias is not None, **module_kwargs
+                )
+                model.set_submodule(module_name, new_module)
+                has_been_replaced = True

     if not has_been_replaced:
         logger.warning(
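The eetq rewrite above (and the fbgemm_fp8 rewrite that follows) drops the old recursive `named_children()` walk with its `current_key_name` bookkeeping in favor of a single flat pass over `named_modules()`: eligibility is decided by the new `quantizers_utils.should_convert_module` helper and the swap is done via `nn.Module.set_submodule`. A minimal sketch of that traversal pattern, assuming a recent PyTorch that provides `Module.set_submodule`; the `QuantLinear` stand-in and the simplified `should_convert` filter are illustrative, not part of the package:

```python
import torch.nn as nn


def should_convert(module_name: str, skip: list[str] | None) -> bool:
    # Simplified stand-in for transformers' should_convert_module:
    # skip the root module and anything whose dotted path matches a pattern.
    if module_name == "":
        return False
    return not any(pat in module_name for pat in (skip or []))


class QuantLinear(nn.Linear):
    """Toy replacement layer standing in for EetqLinear/FbgemmFp8Linear."""


def replace_linears(model: nn.Module, skip: list[str] | None = None) -> bool:
    has_been_replaced = False
    # named_modules() yields every submodule with its fully qualified dotted
    # name, so no explicit recursion or key-stack bookkeeping is needed.
    for name, module in model.named_modules():
        if not should_convert(name, skip):
            continue
        if type(module) is nn.Linear:  # exact type check avoids re-wrapping QuantLinear
            new_module = QuantLinear(module.in_features, module.out_features, bias=module.bias is not None)
            model.set_submodule(name, new_module)  # swap in place by dotted path
            has_been_replaced = True
    return has_been_replaced


model = nn.Sequential(nn.Linear(8, 8), nn.Sequential(nn.Linear(8, 4)))
print(replace_linears(model, skip=["lm_head"]))  # True
```

The real helpers additionally run the swap under accelerate's `init_empty_weights`, so nothing is materialized until the quantized checkpoint is actually loaded.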
transformers/integrations/fbgemm_fp8.py (+191 -145)

@@ -12,8 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from functools import lru_cache
+from typing import Optional
+
 from ..activations import ACT2FN
-from .. …
+from ..core_model_loading import ConversionOps
+from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
+from ..utils import (
+    is_accelerate_available,
+    is_fbgemm_gpu_available,
+    is_torch_available,
+    is_torch_xpu_available,
+    logging,
+)


 if is_torch_available():
@@ -23,24 +34,83 @@ if is_torch_available():
 if is_accelerate_available():
     from accelerate import init_empty_weights

- … (old line 26 truncated in the source view)
+_is_torch_xpu_available = is_torch_xpu_available()
+
+if is_fbgemm_gpu_available() and not _is_torch_xpu_available:
     import fbgemm_gpu.experimental.gen_ai  # noqa: F401

 logger = logging.get_logger(__name__)


+class FbgemmFp8Quantize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: dict[str, torch.Tensor | list[torch.Tensor]],
+        model: Optional[torch.nn.Module] = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        target_key, value = tuple(input_dict.items())[0]
+        value = value[0]
+
+        from ..integrations import FbgemmFp8Llama4TextExperts
+
+        module, tensor_name = get_module_from_name(model, target_key)
+
+        if isinstance(module, FbgemmFp8Llama4TextExperts):
+            if tensor_name == "gate_up_proj":
+                # Process each expert separately
+                # Transpose the second and third dimension
+                transposed_param = value.transpose(1, 2)
+
+                # Reshape to 2D for quantization
+                original_shape = transposed_param.shape
+                flattened_param = transposed_param.reshape(-1, original_shape[-1])
+
+                # Quantize using per row instead of per column
+                new_value_flat, weight_scale_flat = quantize_fp8_per_row(flattened_param)
+
+                # Reshape back to original dimensions
+                new_value = new_value_flat.reshape(original_shape)
+                new_value = new_value.transpose(1, 2)
+                weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
+            elif tensor_name == "down_proj":
+                # Process each expert separately
+                # Transpose the weights for proper quantization
+                transposed_param = value.transpose(1, 2)
+
+                # Reshape to 2D for quantization
+                original_shape = transposed_param.shape
+                flattened_param = transposed_param.reshape(-1, original_shape[-1])
+
+                # Quantize using per column
+                new_value_flat, weight_scale_flat = quantize_fp8_per_row(flattened_param)
+
+                # Reshape back to original dimensions
+                new_value = new_value_flat.reshape(original_shape)
+                new_value = new_value.transpose(1, 2)
+                weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
+        else:
+            new_value, weight_scale = quantize_fp8_per_row(value)
+            weight_scale = torch.nn.Parameter(weight_scale.view(weight_scale.shape[0], 1))
+
+        return {target_key: torch.nn.Parameter(new_value), f"{target_key}_scale": weight_scale}
+
+
 class FbgemmFp8Linear(torch.nn.Linear):
-    def __init__(self, in_features, out_features, bias, …
+    def __init__(self, in_features, out_features, bias, dtype=torch.float8_e4m3fn):
         super().__init__(in_features, out_features, bias)
         self.in_features = in_features
         self.out_features = out_features

-        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype= …
-        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype= …
+        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=dtype))
+        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=torch.float32))
         self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)

         if bias:
-            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype= …
+            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=torch.float32))
         else:
             self.bias = None

@@ -49,18 +119,26 @@ class FbgemmFp8Linear(torch.nn.Linear):
         output_shape = (*x.shape[:-1], -1)
         # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
         # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
-        x_quantized, x_scale = …
-            x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub
-        )
+        x_quantized, x_scale = quantize_fp8_per_row(x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub)
         # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
         # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)

         # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
         weight_scale_float32 = self.weight_scale.to(torch.float32)
- … (old lines 60-63 truncated in the source view)
+        if _is_torch_xpu_available:
+            output = torch._scaled_mm(
+                x_quantized,
+                self.weight.t(),
+                scale_a=x_scale.unsqueeze(-1),
+                scale_b=weight_scale_float32.t(),
+                out_dtype=x.dtype,
+                bias=self.bias,
+            )
+        else:
+            output = torch.ops.fbgemm.f8f8bf16_rowwise(
+                x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
+            )
+            output = output + self.bias if self.bias is not None else output
         # Hacky for now, we have the output to the device of x
         output = output.to(x.device)
         output = output.reshape(output_shape)
@@ -112,168 +190,136 @@ class FbgemmFp8Llama4TextExperts(nn.Module):
             expert_hidden = hidden_states[i]
             expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
             # Quantize for this expert
-            expert_quantized, expert_scale = …
+            expert_quantized, expert_scale = quantize_fp8_per_row(
                 expert_hidden_reshaped, num_tokens, self.input_scale_ub
             )
             sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
             gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
+            if _is_torch_xpu_available:
+                gate = torch._scaled_mm(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous().t(),
+                    scale_a=expert_scale.unsqueeze(-1),
+                    scale_b=gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous().t(),
+                    out_dtype=hidden_states.dtype,
+                )
+                up = torch._scaled_mm(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous().t(),
+                    scale_a=expert_scale.unsqueeze(-1),
+                    scale_b=gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous().t(),
+                    out_dtype=hidden_states.dtype,
+                )
+            else:
+                gate = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
+                    expert_scale,
+                    gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
+                    use_fast_accum=True,
+                )

- … (old lines 121-128 truncated in the source view)
-            up = torch.ops.fbgemm.f8f8bf16_rowwise(
-                expert_quantized,
-                self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
-                expert_scale,
-                gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
-                use_fast_accum=True,
-            )
+                up = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
+                    expert_scale,
+                    gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
+                    use_fast_accum=True,
+                )

             activated = up * self.act_fn(gate)

-            activated_quantized, activated_scale = …
-                activated, num_tokens, self.input_scale_ub
-            )
+            activated_quantized, activated_scale = quantize_fp8_per_row(activated, num_tokens, self.input_scale_ub)

             down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
- … (old lines 144-150 truncated in the source view)
+            if _is_torch_xpu_available:
+                expert_output = torch._scaled_mm(
+                    activated_quantized,
+                    self.down_proj[i].transpose(0, 1).contiguous(),
+                    scale_a=activated_scale.unsqueeze(-1),
+                    scale_b=down_proj_scale_float32[i].view(-1, 1).contiguous().t(),
+                    out_dtype=hidden_states.dtype,
+                )
+            else:
+                expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    activated_quantized,
+                    self.down_proj[i].transpose(0, 1).contiguous(),
+                    activated_scale,
+                    down_proj_scale_float32[i].view(-1, 1).contiguous(),
+                    use_fast_accum=True,
+                )

             next_states[i] = expert_output
         next_states = next_states.to(hidden_states.device)
         return next_states.view(-1, self.hidden_size)


- … (old lines 157-160 truncated in the source view)
-    quantization_config=None,
-    has_been_replaced=False,
-    pre_quantized=False,
-    config=None,
-    tp_plan=None,
-):
-    """
-    Private method that wraps the recursion for module replacement.
-
-    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
-    """
-
-    import re
-
-    if current_key_name is None:
-        current_key_name = []
-
-    for name, module in model.named_children():
-        current_key_name.append(name)
-
-        if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            current_key_name_str = ".".join(current_key_name)
-            if not any(
-                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-            ):
-                with init_empty_weights(include_buffers=True):
-                    in_features = module.in_features
-                    out_features = module.out_features
-                    model._modules[name] = FbgemmFp8Linear(
-                        in_features,
-                        out_features,
-                        module.bias is not None,
-                    )
-                    has_been_replaced = True
-
-                # Force requires grad to False to avoid unexpected errors
-                model._modules[name].requires_grad_(False)
-            # set non persistent buffer outside of init_empty_weights
-            model._modules[name].input_scale_ub = torch.tensor(
-                [quantization_config.activation_scale_ub],
-                dtype=torch.float,
-            )
-        if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
-            current_key_name_str = ".".join(current_key_name)
-            if not any(
-                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-            ):
-                with init_empty_weights(include_buffers=True):
-                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
-                    model._modules[name] = FbgemmFp8Llama4TextExperts(
-                        config.text_config,
-                    )
-                model._modules[name].input_scale_ub = torch.tensor(
-                    [quantization_config.activation_scale_ub], dtype=torch.float
-                )
+@lru_cache(maxsize=1)
+def get_quantize_fp8_per_row():
+    if _is_torch_xpu_available:
+        from kernels import get_kernel

- … (old lines 218-219 truncated in the source view)
-        module,
-        modules_to_not_convert,
-        current_key_name,
-        quantization_config,
-        has_been_replaced=has_been_replaced,
-        pre_quantized=pre_quantized,
-        config=config,
-        tp_plan=tp_plan,
-    )
-    # Remove the last key for recursion
-    current_key_name.pop(-1)
-    return model, has_been_replaced
+        return get_kernel("kernels-community/fp8-fbgemm").quantize_fp8_per_row
+    return torch.ops.fbgemm.quantize_fp8_per_row


 def replace_with_fbgemm_fp8_linear(
-    model,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    pre_quantized=False,
-    config=None,
-    tp_plan=None,
+    model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False, tp_plan=None
 ):
     """
     A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
     This will enable running your models using high performance fp8 kernel from FBGEMM library.

-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
-    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
-    CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
-
     Parameters:
         model (`torch.nn.Module`):
             Input model or `torch.nn.Module` as the function is run recursively.
-        modules_to_not_convert (`list[`str`]`, *optional*, defaults to ` …
-            Names of the modules to not convert …
- … (old lines 256-259 truncated in the source view)
-            `disk`).
+        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`):
+            Names of the modules to not convert. In practice we keep the `lm_head` in full precision for numerical stability reasons.
+        quantization_config (`FbgemmFp8Config`):
+            The quantization config object that contains the quantization parameters.
+        pre_quantized (`book`, defaults to `False`):
+            Whether the model is pre-quantized or not
     """
+    global quantize_fp8_per_row
+    quantize_fp8_per_row = get_quantize_fp8_per_row()
+
+    has_been_replaced = False
+    module_kwargs = {} if pre_quantized else {"dtype": None}
+
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+
+        new_module = None
+        with init_empty_weights(include_buffers=True):
+            if module.__class__.__name__ == "Llama4TextExperts":
+                # TODO: make sure tp works later
+                # if tp_plan is not None:
+                #     tp_key = re.sub(r"\d+", "*", f"{module_name}.down_proj_scale")
+                #     tp_plan[tp_key] = None
+                text_config = getattr(model.config, "text_config", model.config)
+                new_module = FbgemmFp8Llama4TextExperts(text_config or model.config)
+            elif isinstance(module, nn.Linear):
+                new_module = FbgemmFp8Linear(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    **module_kwargs,
+                )
+                new_module.requires_grad_(False)
+
+        if new_module is None:
+            continue
+
+        if hasattr(new_module, "input_scale_ub"):
+            new_module.input_scale_ub = torch.tensor(
+                [quantization_config.activation_scale_ub],
+                dtype=torch.float,
+            )
+
+        model.set_submodule(module_name, new_module)
+        has_been_replaced = True

-    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
-
-    if quantization_config.modules_to_not_convert is not None:
-        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
-    modules_to_not_convert = list(set(modules_to_not_convert))
-    model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
-        model,
-        modules_to_not_convert,
-        current_key_name,
-        quantization_config,
-        pre_quantized=pre_quantized,
-        config=config,
-        tp_plan=tp_plan,
-    )
     if not has_been_replaced:
         logger.warning(
             "You are loading your model using FP8 quantization but no linear modules were found in your model."