transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -21,7 +21,6 @@ import inspect
|
|
|
21
21
|
import os
|
|
22
22
|
import re
|
|
23
23
|
from collections import OrderedDict, defaultdict
|
|
24
|
-
from contextlib import contextmanager
|
|
25
24
|
from typing import TYPE_CHECKING
|
|
26
25
|
|
|
27
26
|
from safetensors import safe_open
|
|
@@ -55,114 +54,6 @@ if TYPE_CHECKING:
|
|
|
55
54
|
logger = logging.get_logger(__name__)
|
|
56
55
|
|
|
57
56
|
|
|
58
|
-
@contextmanager
|
|
59
|
-
def init_empty_weights(include_buffers: bool = False):
|
|
60
|
-
"""
|
|
61
|
-
A context manager under which models are initialized with all parameters on the meta device, therefore creating an
|
|
62
|
-
empty model. Useful when just initializing the model would blow the available RAM.
|
|
63
|
-
|
|
64
|
-
Args:
|
|
65
|
-
include_buffers (`bool`, *optional*):
|
|
66
|
-
Whether or not to also put all buffers on the meta device while initializing.
|
|
67
|
-
|
|
68
|
-
Example:
|
|
69
|
-
|
|
70
|
-
```python
|
|
71
|
-
import torch.nn as nn
|
|
72
|
-
from accelerate import init_empty_weights
|
|
73
|
-
|
|
74
|
-
# Initialize a model with 100 billions parameters in no time and without using any RAM.
|
|
75
|
-
with init_empty_weights():
|
|
76
|
-
tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
|
|
77
|
-
```
|
|
78
|
-
|
|
79
|
-
<Tip warning={true}>
|
|
80
|
-
|
|
81
|
-
Any model created under this context manager has no weights. As such you can't do something like
|
|
82
|
-
`model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
|
|
83
|
-
Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
|
|
84
|
-
called.
|
|
85
|
-
|
|
86
|
-
</Tip>
|
|
87
|
-
"""
|
|
88
|
-
with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
|
|
89
|
-
yield f
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
@contextmanager
|
|
93
|
-
def init_on_device(device: "torch.device", include_buffers: bool = False):
|
|
94
|
-
"""
|
|
95
|
-
A context manager under which models are initialized with all parameters on the specified device.
|
|
96
|
-
|
|
97
|
-
Args:
|
|
98
|
-
device (`torch.device`):
|
|
99
|
-
Device to initialize all parameters on.
|
|
100
|
-
include_buffers (`bool`, *optional*):
|
|
101
|
-
Whether or not to also put all buffers on the meta device while initializing.
|
|
102
|
-
|
|
103
|
-
Example:
|
|
104
|
-
|
|
105
|
-
```python
|
|
106
|
-
import torch.nn as nn
|
|
107
|
-
from accelerate import init_on_device
|
|
108
|
-
|
|
109
|
-
with init_on_device(device=torch.device("cuda")):
|
|
110
|
-
tst = nn.Linear(100, 100) # on `cuda` device
|
|
111
|
-
```
|
|
112
|
-
"""
|
|
113
|
-
if include_buffers:
|
|
114
|
-
with device:
|
|
115
|
-
yield
|
|
116
|
-
return
|
|
117
|
-
|
|
118
|
-
old_register_parameter = nn.Module.register_parameter
|
|
119
|
-
if include_buffers:
|
|
120
|
-
old_register_buffer = nn.Module.register_buffer
|
|
121
|
-
|
|
122
|
-
def register_empty_parameter(module, name, param):
|
|
123
|
-
old_register_parameter(module, name, param)
|
|
124
|
-
if param is not None:
|
|
125
|
-
param_cls = type(module._parameters[name])
|
|
126
|
-
kwargs = module._parameters[name].__dict__
|
|
127
|
-
kwargs["requires_grad"] = param.requires_grad
|
|
128
|
-
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
|
129
|
-
|
|
130
|
-
def register_empty_buffer(module, name, buffer, persistent=True):
|
|
131
|
-
old_register_buffer(module, name, buffer, persistent=persistent)
|
|
132
|
-
if buffer is not None:
|
|
133
|
-
module._buffers[name] = module._buffers[name].to(device)
|
|
134
|
-
|
|
135
|
-
# Patch tensor creation
|
|
136
|
-
if include_buffers:
|
|
137
|
-
tensor_constructors_to_patch = {
|
|
138
|
-
torch_function_name: getattr(torch, torch_function_name)
|
|
139
|
-
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
|
140
|
-
}
|
|
141
|
-
else:
|
|
142
|
-
tensor_constructors_to_patch = {}
|
|
143
|
-
|
|
144
|
-
def patch_tensor_constructor(fn):
|
|
145
|
-
def wrapper(*args, **kwargs):
|
|
146
|
-
kwargs["device"] = device
|
|
147
|
-
return fn(*args, **kwargs)
|
|
148
|
-
|
|
149
|
-
return wrapper
|
|
150
|
-
|
|
151
|
-
try:
|
|
152
|
-
nn.Module.register_parameter = register_empty_parameter
|
|
153
|
-
if include_buffers:
|
|
154
|
-
nn.Module.register_buffer = register_empty_buffer
|
|
155
|
-
for torch_function_name in tensor_constructors_to_patch:
|
|
156
|
-
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
|
157
|
-
yield
|
|
158
|
-
finally:
|
|
159
|
-
nn.Module.register_parameter = old_register_parameter
|
|
160
|
-
if include_buffers:
|
|
161
|
-
nn.Module.register_buffer = old_register_buffer
|
|
162
|
-
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
|
163
|
-
setattr(torch, torch_function_name, old_torch_function)
|
|
164
|
-
|
|
165
|
-
|
|
166
57
|
def check_and_set_device_map(device_map: "torch.device | int | str | dict | None") -> dict | str | None:
|
|
167
58
|
from ..modeling_utils import get_torch_context_manager_or_global_device
|
|
168
59
|
|
|
@@ -182,6 +73,10 @@ def check_and_set_device_map(device_map: "torch.device | int | str | dict | None
|
|
|
182
73
|
device_map = {"": device_map}
|
|
183
74
|
elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
|
|
184
75
|
try:
|
|
76
|
+
if device_map == "cuda":
|
|
77
|
+
# setting to the local rank
|
|
78
|
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
79
|
+
device_map = f"cuda:{local_rank}"
|
|
185
80
|
device_map = {"": torch.device(device_map)}
|
|
186
81
|
except RuntimeError:
|
|
187
82
|
raise ValueError(
|
|
@@ -398,7 +293,7 @@ def _get_device_map(
|
|
|
398
293
|
# especially if the model uses WeightConverter (because there will be some uncontrollable cpu memory spikes during
|
|
399
294
|
# the conversions before we resave the weights). In those cases, it's better to offload to disk a bit more
|
|
400
295
|
# if we were in-between, as otherwise we blow-up cpu memory
|
|
401
|
-
if max_memory is None:
|
|
296
|
+
if max_memory is None and "cpu" in inferred_max_memory:
|
|
402
297
|
inferred_max_memory["cpu"] *= 0.90
|
|
403
298
|
|
|
404
299
|
if hf_quantizer is not None:
|
|
@@ -458,10 +353,13 @@ def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload
|
|
|
458
353
|
dispatch_model(model, **device_map_kwargs)
|
|
459
354
|
|
|
460
355
|
|
|
461
|
-
def expand_device_map(device_map, param_names):
|
|
356
|
+
def expand_device_map(device_map: dict | None, param_names: list[str]):
|
|
462
357
|
"""
|
|
463
358
|
Expand a device map to return the correspondence parameter name to device.
|
|
464
359
|
"""
|
|
360
|
+
if device_map is None:
|
|
361
|
+
return dict.fromkeys(param_names, "cpu")
|
|
362
|
+
|
|
465
363
|
# Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
|
|
466
364
|
device_map_regex = re.compile(
|
|
467
365
|
"|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
|
|
@@ -474,6 +372,15 @@ def expand_device_map(device_map, param_names):
|
|
|
474
372
|
return new_device_map
|
|
475
373
|
|
|
476
374
|
|
|
375
|
+
def get_device(device_map: dict | None, param_name: str, valid_torch_device: bool = False) -> torch.device | str | int:
|
|
376
|
+
"""Return the device on which `param_name` should be according to the `device_map`. If `valid_torch_device` is `True`,
|
|
377
|
+
then if the device is `"disk"`, `"cpu"` will be returned instead."""
|
|
378
|
+
device = expand_device_map(device_map, [param_name])[param_name]
|
|
379
|
+
if valid_torch_device and device == "disk":
|
|
380
|
+
return "cpu"
|
|
381
|
+
return device
|
|
382
|
+
|
|
383
|
+
|
|
477
384
|
def accelerate_disk_offload(
|
|
478
385
|
model: "PreTrainedModel",
|
|
479
386
|
disk_offload_folder: str | None,
|
|
@@ -554,6 +461,32 @@ def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str |
|
|
|
554
461
|
return offload_index
|
|
555
462
|
|
|
556
463
|
|
|
464
|
+
def load_offloaded_parameter(model: "PreTrainedModel", param_name: str) -> torch.Tensor:
|
|
465
|
+
"""Load `param_name` from disk, if it was offloaded due to the device_map, and thus lives as a meta parameter
|
|
466
|
+
inside `model`.
|
|
467
|
+
This is needed when resaving a model, when some parameters were offloaded (we need to load them from disk, to
|
|
468
|
+
then resave them to disk in the correct shard...)."""
|
|
469
|
+
# Start from the most inner module, and try to find the hook that was used for offloading the param
|
|
470
|
+
module_parts = param_name.split(".")
|
|
471
|
+
modules_to_check = [".".join(module_parts[:-idx]) for idx in range(1, len(module_parts))] + [""]
|
|
472
|
+
for parent_name in modules_to_check:
|
|
473
|
+
parent = model.get_submodule(parent_name)
|
|
474
|
+
if hasattr(parent, "_hf_hook"):
|
|
475
|
+
weights_map = parent._hf_hook.weights_map
|
|
476
|
+
truncated_param_name = param_name.replace(f"{parent_name}." if parent_name != "" else parent_name, "")
|
|
477
|
+
break
|
|
478
|
+
# If we did not break the loop, something is wrong
|
|
479
|
+
else:
|
|
480
|
+
raise ValueError(
|
|
481
|
+
f"{param_name} is on the meta device because it was offloaded, but we could not find "
|
|
482
|
+
"the corresponding hook for it"
|
|
483
|
+
)
|
|
484
|
+
|
|
485
|
+
# This call loads it from disk
|
|
486
|
+
tensor = weights_map[truncated_param_name]
|
|
487
|
+
return tensor
|
|
488
|
+
|
|
489
|
+
|
|
557
490
|
def _init_infer_auto_device_map(
|
|
558
491
|
model: nn.Module,
|
|
559
492
|
max_memory: dict[int | str, int | str] | None = None,
|
|
@@ -14,13 +14,11 @@
|
|
|
14
14
|
"AQLM (Additive Quantization of Language Model) integration file"
|
|
15
15
|
|
|
16
16
|
from ..quantizers.quantizers_utils import should_convert_module
|
|
17
|
-
from ..utils import
|
|
17
|
+
from ..utils import is_torch_available, logging
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
if is_accelerate_available():
|
|
21
|
-
from accelerate import init_empty_weights
|
|
22
|
-
|
|
23
20
|
if is_torch_available():
|
|
21
|
+
import torch
|
|
24
22
|
import torch.nn as nn
|
|
25
23
|
|
|
26
24
|
logger = logging.get_logger(__name__)
|
|
@@ -46,7 +44,7 @@ def replace_with_aqlm_linear(model, modules_to_not_convert: list[str] | None = N
|
|
|
46
44
|
for module_name, module in model.named_modules():
|
|
47
45
|
if not should_convert_module(module_name, modules_to_not_convert):
|
|
48
46
|
continue
|
|
49
|
-
with
|
|
47
|
+
with torch.device("meta"):
|
|
50
48
|
if isinstance(module, nn.Linear):
|
|
51
49
|
new_module = QuantizedLinear(
|
|
52
50
|
module.in_features,
|
transformers/integrations/awq.py
CHANGED
|
@@ -16,12 +16,9 @@
|
|
|
16
16
|
from typing import Optional, Union
|
|
17
17
|
|
|
18
18
|
from ..quantizers.quantizers_utils import should_convert_module
|
|
19
|
-
from ..utils import
|
|
19
|
+
from ..utils import is_torch_available, logging
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
if is_accelerate_available():
|
|
23
|
-
from accelerate import init_empty_weights
|
|
24
|
-
|
|
25
22
|
if is_torch_available():
|
|
26
23
|
import torch
|
|
27
24
|
import torch.nn as nn
|
|
@@ -97,7 +94,7 @@ def replace_with_awq_linear(
|
|
|
97
94
|
for module_name, module in model.named_modules():
|
|
98
95
|
if not should_convert_module(module_name, modules_to_not_convert):
|
|
99
96
|
continue
|
|
100
|
-
with
|
|
97
|
+
with torch.device("meta"):
|
|
101
98
|
if isinstance(module, nn.Linear):
|
|
102
99
|
new_module = target_cls(
|
|
103
100
|
bits=quantization_config.bits,
|
|
@@ -1,10 +1,7 @@
|
|
|
1
1
|
from ..quantizers.quantizers_utils import should_convert_module
|
|
2
|
-
from ..utils import
|
|
2
|
+
from ..utils import is_torch_available, logging
|
|
3
3
|
|
|
4
4
|
|
|
5
|
-
if is_accelerate_available():
|
|
6
|
-
from accelerate import init_empty_weights
|
|
7
|
-
|
|
8
5
|
if is_torch_available():
|
|
9
6
|
import torch
|
|
10
7
|
import torch.nn as nn
|
|
@@ -92,7 +89,7 @@ def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
|
|
92
89
|
|
|
93
90
|
Explanation of the example:
|
|
94
91
|
---------------------------
|
|
95
|
-
Let's take the first value for example 0b10100001, we
|
|
92
|
+
Let's take the first value for example 0b10100001, we will only focus on the first column,
|
|
96
93
|
because every element is unpacked across the first dimension
|
|
97
94
|
- First 2 bits: `01` → 0 at [0][0]
|
|
98
95
|
- Second 2 bits: `00` → -1 at [0][2]
|
|
@@ -173,7 +170,7 @@ class BitLinear(nn.Module):
|
|
|
173
170
|
Activation function : Performs symmetric, per-token quantization on the input activations.
|
|
174
171
|
Parameters:
|
|
175
172
|
-----------
|
|
176
|
-
|
|
173
|
+
input : torch.Tensor
|
|
177
174
|
Input activations to be quantized.
|
|
178
175
|
num_bits : int, optional (default=8)
|
|
179
176
|
Number of bits to use for quantization, determining the quantization range.
|
|
@@ -334,7 +331,7 @@ def replace_with_bitnet_linear(model, modules_to_not_convert: list[str] | None =
|
|
|
334
331
|
for module_name, module in model.named_modules():
|
|
335
332
|
if not should_convert_module(module_name, modules_to_not_convert):
|
|
336
333
|
continue
|
|
337
|
-
with
|
|
334
|
+
with torch.device("meta"):
|
|
338
335
|
if isinstance(module, nn.Linear):
|
|
339
336
|
if quantization_config and quantization_config.linear_class == "autobitlinear":
|
|
340
337
|
new_module = AutoBitLinear(
|
|
@@ -365,7 +362,7 @@ def replace_with_bitnet_linear(model, modules_to_not_convert: list[str] | None =
|
|
|
365
362
|
|
|
366
363
|
if not has_been_replaced:
|
|
367
364
|
logger.warning(
|
|
368
|
-
"You are loading your model using
|
|
365
|
+
"You are loading your model using bitnet but no linear modules were found in your model."
|
|
369
366
|
" Please double check your model architecture, or submit an issue on github if you think this is"
|
|
370
367
|
" a bug."
|
|
371
368
|
)
|
|
@@ -22,7 +22,6 @@ if is_torch_available():
|
|
|
22
22
|
|
|
23
23
|
if is_accelerate_available():
|
|
24
24
|
import accelerate
|
|
25
|
-
from accelerate import init_empty_weights
|
|
26
25
|
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
|
27
26
|
|
|
28
27
|
logger = logging.get_logger(__name__)
|
|
@@ -181,7 +180,7 @@ def replace_with_bnb_linear(
|
|
|
181
180
|
if not should_convert_module(module_name, modules_to_not_convert):
|
|
182
181
|
continue
|
|
183
182
|
new_module = None
|
|
184
|
-
with
|
|
183
|
+
with torch.device("meta"):
|
|
185
184
|
if isinstance(module, (nn.Linear, Conv1D)):
|
|
186
185
|
if isinstance(module, Conv1D):
|
|
187
186
|
in_features, out_features = module.weight.shape
|
|
@@ -233,7 +232,7 @@ def replace_with_bnb_linear(
|
|
|
233
232
|
|
|
234
233
|
|
|
235
234
|
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
|
|
236
|
-
def dequantize_bnb_weight(weight: "torch.nn.Parameter",
|
|
235
|
+
def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
|
|
237
236
|
"""
|
|
238
237
|
Helper function to dequantize 4bit or 8bit bnb weights.
|
|
239
238
|
|
|
@@ -248,10 +247,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", st
|
|
|
248
247
|
|
|
249
248
|
if cls_name == "Params4bit":
|
|
250
249
|
output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
|
|
251
|
-
|
|
252
|
-
f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
|
|
253
|
-
)
|
|
254
|
-
return output_tensor.to(dtype)
|
|
250
|
+
return output_tensor
|
|
255
251
|
|
|
256
252
|
if state.SCB is None:
|
|
257
253
|
state.SCB = weight.SCB
|
|
@@ -263,7 +259,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", st
|
|
|
263
259
|
# Multiply by (scale/127) to dequantize.
|
|
264
260
|
dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
|
|
265
261
|
|
|
266
|
-
return dequantized
|
|
262
|
+
return dequantized
|
|
267
263
|
|
|
268
264
|
|
|
269
265
|
def _create_accelerate_new_hook(old_hook):
|
|
@@ -283,10 +279,7 @@ def _create_accelerate_new_hook(old_hook):
|
|
|
283
279
|
return new_hook
|
|
284
280
|
|
|
285
281
|
|
|
286
|
-
def dequantize_and_replace(
|
|
287
|
-
model,
|
|
288
|
-
quantization_config=None,
|
|
289
|
-
):
|
|
282
|
+
def dequantize_and_replace(model, quantization_config=None, dtype=None):
|
|
290
283
|
"""
|
|
291
284
|
Converts a quantized model into its dequantized original version. The newly converted model will have
|
|
292
285
|
some performance drop compared to the original model before quantization - use it only for specific usecases
|
|
@@ -297,14 +290,22 @@ def dequantize_and_replace(
|
|
|
297
290
|
quant_method = quantization_config.quantization_method()
|
|
298
291
|
|
|
299
292
|
target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
|
|
300
|
-
|
|
301
293
|
for module_name, module in model.named_modules():
|
|
302
294
|
if isinstance(module, target_cls):
|
|
303
|
-
with
|
|
295
|
+
with torch.device("meta"):
|
|
304
296
|
bias = getattr(module, "bias", None)
|
|
305
297
|
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
|
|
306
298
|
state = module.state if quant_method == "llm_int8" else None
|
|
307
|
-
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight,
|
|
299
|
+
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
|
|
300
|
+
weight = dequantize_bnb_weight(module.weight, state)
|
|
301
|
+
if dtype is None:
|
|
302
|
+
logger.warning_once(
|
|
303
|
+
f"The modules are dequantized in {weight.dtype}. If you want to change the dtype, please specify `dtype` in `dequantize`. "
|
|
304
|
+
)
|
|
305
|
+
else:
|
|
306
|
+
logger.warning_once(f"The modules are dequantized in {weight.dtype} and casted to {dtype}.")
|
|
307
|
+
weight = weight.to(dtype)
|
|
308
|
+
new_module.weight = torch.nn.Parameter(weight)
|
|
308
309
|
if bias is not None:
|
|
309
310
|
new_module.bias = bias
|
|
310
311
|
if hasattr(module, "_hf_hook"):
|
|
@@ -304,6 +304,15 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict):
|
|
|
304
304
|
state_dict._metadata = metadata
|
|
305
305
|
|
|
306
306
|
error_msgs = []
|
|
307
|
+
meta_model_state_dict = model_to_load.state_dict()
|
|
308
|
+
missing_keys = set(meta_model_state_dict.keys())
|
|
309
|
+
|
|
310
|
+
prefix_model = getattr(model_to_load, "base_model_prefix", None)
|
|
311
|
+
# take care of the case where in the checkpoint we don't have the prefix
|
|
312
|
+
state_dict = {
|
|
313
|
+
(f"{prefix_model}.{k}" if meta_model_state_dict.get(f"{prefix_model}.{k}") is not None else k): v
|
|
314
|
+
for k, v in state_dict.items()
|
|
315
|
+
}
|
|
307
316
|
|
|
308
317
|
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
|
309
318
|
# so we need to apply the function recursively.
|
|
@@ -320,7 +329,14 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict):
|
|
|
320
329
|
# In sharded models, each shard has only part of the full state_dict, so only gather
|
|
321
330
|
# parameters that are in the current state_dict.
|
|
322
331
|
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
|
|
323
|
-
params_to_gather = [
|
|
332
|
+
params_to_gather = []
|
|
333
|
+
for k in named_parameters:
|
|
334
|
+
if k in state_dict:
|
|
335
|
+
param = named_parameters[k]
|
|
336
|
+
# crutial to not init the weight again
|
|
337
|
+
param._is_hf_initialized = True
|
|
338
|
+
params_to_gather.append(param)
|
|
339
|
+
missing_keys.discard(k)
|
|
324
340
|
|
|
325
341
|
if len(params_to_gather) > 0:
|
|
326
342
|
# because zero3 puts placeholders in model params, this context
|
|
@@ -333,11 +349,10 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict):
|
|
|
333
349
|
for name, child in module._modules.items():
|
|
334
350
|
if child is not None:
|
|
335
351
|
load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
|
|
336
|
-
child._is_hf_initialized = True
|
|
337
352
|
|
|
338
353
|
load(model_to_load, state_dict, assign_to_params_buffers=False)
|
|
339
354
|
|
|
340
|
-
return error_msgs
|
|
355
|
+
return error_msgs, missing_keys
|
|
341
356
|
|
|
342
357
|
|
|
343
358
|
def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
|
|
@@ -14,15 +14,13 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
from ..core_model_loading import ConversionOps
|
|
16
16
|
from ..quantizers.quantizers_utils import should_convert_module
|
|
17
|
-
from ..utils import
|
|
17
|
+
from ..utils import is_torch_available, logging
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
if is_torch_available():
|
|
21
21
|
import torch
|
|
22
22
|
import torch.nn as nn
|
|
23
23
|
|
|
24
|
-
if is_accelerate_available():
|
|
25
|
-
from accelerate import init_empty_weights
|
|
26
24
|
|
|
27
25
|
logger = logging.get_logger(__name__)
|
|
28
26
|
|
|
@@ -97,7 +95,7 @@ def replace_with_eetq_linear(model, modules_to_not_convert: list[str] | None = N
|
|
|
97
95
|
Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
|
|
98
96
|
for numerical stability reasons.
|
|
99
97
|
"""
|
|
100
|
-
from
|
|
98
|
+
from .hub_kernels import get_kernel
|
|
101
99
|
|
|
102
100
|
global eetq_kernels_hub
|
|
103
101
|
eetq_kernels_hub = get_kernel("kernels-community/quantization-eetq")
|
|
@@ -108,7 +106,7 @@ def replace_with_eetq_linear(model, modules_to_not_convert: list[str] | None = N
|
|
|
108
106
|
for module_name, module in model.named_modules():
|
|
109
107
|
if not should_convert_module(module_name, modules_to_not_convert):
|
|
110
108
|
continue
|
|
111
|
-
with
|
|
109
|
+
with torch.device("meta"):
|
|
112
110
|
if isinstance(module, nn.Linear):
|
|
113
111
|
new_module = EetqLinear(
|
|
114
112
|
module.in_features, module.out_features, bias=module.bias is not None, **module_kwargs
|
|
@@ -257,7 +257,7 @@ class FbgemmFp8Llama4TextExperts(nn.Module):
|
|
|
257
257
|
@lru_cache(maxsize=1)
|
|
258
258
|
def get_quantize_fp8_per_row():
|
|
259
259
|
if _is_torch_xpu_available:
|
|
260
|
-
from
|
|
260
|
+
from .hub_kernels import get_kernel
|
|
261
261
|
|
|
262
262
|
return get_kernel("kernels-community/fp8-fbgemm").quantize_fp8_per_row
|
|
263
263
|
return torch.ops.fbgemm.quantize_fp8_per_row
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
from ..core_model_loading import ConversionOps
|
|
17
17
|
from ..quantizers.quantizers_utils import should_convert_module
|
|
18
|
-
from ..utils import
|
|
18
|
+
from ..utils import is_torch_accelerator_available, is_torch_available, logging
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
if is_torch_available():
|
|
@@ -25,23 +25,16 @@ if is_torch_available():
|
|
|
25
25
|
import triton.language as tl
|
|
26
26
|
from torch.nn import functional as F
|
|
27
27
|
|
|
28
|
-
if is_accelerate_available():
|
|
29
|
-
from accelerate import init_empty_weights
|
|
30
|
-
|
|
31
28
|
|
|
32
29
|
logger = logging.get_logger(__name__)
|
|
33
30
|
try:
|
|
34
31
|
_FP8_DTYPE = torch.float8_e4m3fn
|
|
35
32
|
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
|
|
36
33
|
_FP8_MAX = torch.finfo(_FP8_DTYPE).max
|
|
37
|
-
_FP8_IS_INT = False
|
|
38
34
|
except AttributeError:
|
|
39
|
-
_FP8_DTYPE =
|
|
40
|
-
_FP8_MIN, _FP8_MAX = -
|
|
41
|
-
|
|
42
|
-
logger.warning_once(
|
|
43
|
-
"torch.float8_e4m3fn not available; falling back to int8 emulation for Fp8Quantize operations."
|
|
44
|
-
)
|
|
35
|
+
_FP8_DTYPE = None
|
|
36
|
+
_FP8_MIN, _FP8_MAX = -448, 448
|
|
37
|
+
logger.warning_once("torch.float8_e4m3fn not available")
|
|
45
38
|
|
|
46
39
|
|
|
47
40
|
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
|
|
@@ -618,7 +611,7 @@ def replace_with_fp8_linear(
|
|
|
618
611
|
# we need this to correctly materialize the weights during quantization
|
|
619
612
|
module_kwargs = {} if pre_quantized else {"dtype": None}
|
|
620
613
|
new_module = None
|
|
621
|
-
with
|
|
614
|
+
with torch.device("meta"):
|
|
622
615
|
if module_name.endswith(".experts"):
|
|
623
616
|
new_module = FP8Expert(
|
|
624
617
|
config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs
|
|
@@ -701,10 +694,7 @@ class Fp8Quantize(ConversionOps):
|
|
|
701
694
|
scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1)
|
|
702
695
|
scaled = reshaped * scales_broadcast
|
|
703
696
|
|
|
704
|
-
|
|
705
|
-
quantized = torch.clamp(scaled.round(), min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
|
|
706
|
-
else:
|
|
707
|
-
quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
|
|
697
|
+
quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
|
|
708
698
|
|
|
709
699
|
quantized = quantized.reshape(original_shape)
|
|
710
700
|
|
|
@@ -20,8 +20,8 @@ def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtyp
|
|
|
20
20
|
else torch.get_autocast_gpu_dtype()
|
|
21
21
|
)
|
|
22
22
|
# Handle the case where the model is quantized
|
|
23
|
-
elif hasattr(module.config, "
|
|
24
|
-
return module.config.
|
|
23
|
+
elif hasattr(module.config, "quantization_config"):
|
|
24
|
+
return module.config.dtype
|
|
25
25
|
else:
|
|
26
26
|
return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
|
|
27
27
|
return None
|
|
@@ -16,12 +16,9 @@
|
|
|
16
16
|
from math import sqrt
|
|
17
17
|
|
|
18
18
|
from ..quantizers.quantizers_utils import should_convert_module
|
|
19
|
-
from ..utils import
|
|
19
|
+
from ..utils import is_flute_available, is_hadamard_available, is_torch_available, logging
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
if is_accelerate_available():
|
|
23
|
-
from accelerate import init_empty_weights
|
|
24
|
-
|
|
25
22
|
if is_torch_available():
|
|
26
23
|
import torch
|
|
27
24
|
import torch.nn as nn
|
|
@@ -569,7 +566,7 @@ def replace_with_higgs_linear(model, modules_to_not_convert: list[str] | None =
|
|
|
569
566
|
for module_name, module in model.named_modules():
|
|
570
567
|
if not should_convert_module(module_name, modules_to_not_convert):
|
|
571
568
|
continue
|
|
572
|
-
with
|
|
569
|
+
with torch.device("meta"):
|
|
573
570
|
if isinstance(module, nn.Linear):
|
|
574
571
|
new_module = HiggsLinear(
|
|
575
572
|
module.in_features,
|
|
@@ -11,11 +11,14 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
import importlib.metadata
|
|
14
15
|
import os
|
|
15
16
|
import re
|
|
16
17
|
from collections.abc import Callable
|
|
17
18
|
from types import ModuleType
|
|
18
19
|
|
|
20
|
+
from packaging import version as pkg_version
|
|
21
|
+
|
|
19
22
|
from ..utils import ENV_VARS_TRUE_VALUES, logging
|
|
20
23
|
from ..utils.import_utils import is_kernels_available
|
|
21
24
|
from .flash_attention import flash_attention_forward
|
|
@@ -28,10 +31,12 @@ try:
|
|
|
28
31
|
Device,
|
|
29
32
|
LayerRepository,
|
|
30
33
|
Mode,
|
|
31
|
-
get_kernel,
|
|
32
34
|
register_kernel_mapping,
|
|
33
35
|
replace_kernel_forward_from_hub,
|
|
34
36
|
)
|
|
37
|
+
from kernels import (
|
|
38
|
+
get_kernel as get_kernel_hub,
|
|
39
|
+
)
|
|
35
40
|
from kernels import (
|
|
36
41
|
use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub,
|
|
37
42
|
)
|
|
@@ -340,8 +345,6 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
|
|
|
340
345
|
mapping[kernel_name] = None
|
|
341
346
|
return None
|
|
342
347
|
if _kernels_available:
|
|
343
|
-
from kernels import get_kernel
|
|
344
|
-
|
|
345
348
|
try:
|
|
346
349
|
repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"]
|
|
347
350
|
revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None)
|
|
@@ -370,7 +373,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
|
|
|
370
373
|
if callable(is_kernel_available) and is_kernel_available():
|
|
371
374
|
# Try to import the module "{kernel_name}" from parent package level
|
|
372
375
|
try:
|
|
373
|
-
module = importlib.import_module(f"{
|
|
376
|
+
module = importlib.import_module(f"{new_kernel_name}")
|
|
374
377
|
mapping[kernel_name] = module
|
|
375
378
|
return module
|
|
376
379
|
except Exception:
|
|
@@ -381,6 +384,20 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
|
|
|
381
384
|
return mapping[kernel_name]
|
|
382
385
|
|
|
383
386
|
|
|
387
|
+
def get_kernel(kernel_name: str, revision: str | None = None, version: str | None = None) -> ModuleType:
|
|
388
|
+
from .. import __version__
|
|
389
|
+
|
|
390
|
+
user_agent = {"framework": "transformers", "version": __version__, "repo_id": kernel_name}
|
|
391
|
+
if _kernels_available:
|
|
392
|
+
kernels_version = importlib.metadata.version("kernels")
|
|
393
|
+
if pkg_version.parse(kernels_version) >= pkg_version.parse("0.10.4"):
|
|
394
|
+
return get_kernel_hub(kernel_name, revision=revision, version=version, user_agent=user_agent)
|
|
395
|
+
else:
|
|
396
|
+
return get_kernel_hub(kernel_name, revision=revision)
|
|
397
|
+
else:
|
|
398
|
+
raise ImportError("kernels is not installed, please install it with `pip install kernels`")
|
|
399
|
+
|
|
400
|
+
|
|
384
401
|
def use_kernelized_func(module_names: list[Callable] | Callable):
|
|
385
402
|
"""
|
|
386
403
|
This decorator attaches the target function as an attribute of the module.
|
|
@@ -415,5 +432,6 @@ __all__ = [
|
|
|
415
432
|
"register_kernel_mapping_transformers",
|
|
416
433
|
"replace_kernel_forward_from_hub",
|
|
417
434
|
"lazy_load_kernel",
|
|
435
|
+
"get_kernel",
|
|
418
436
|
"use_kernelized_func",
|
|
419
|
-
]
|
|
437
|
+
] # type: ignore
|