transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/gptj/modeling_gptj.py

@@ -14,12 +14,14 @@
 # limitations under the License.
 """PyTorch GPT-J model."""

+import math
 from typing import Optional, Union

 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
@@ -77,7 +79,7 @@ class GPTJAttention(nn.Module):
     def __init__(self, config, layer_idx=None):
         super().__init__()
         self.config = config
-        max_positions = config.max_position_embeddings
+        self.max_positions = config.max_position_embeddings

         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
@@ -99,15 +101,17 @@ class GPTJAttention(nn.Module):
                 f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                 f" `num_attention_heads`: {self.num_attention_heads})."
             )
-        self.scale_attn =
+        self.scale_attn = math.sqrt(self.head_dim)

         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.rotary_dim = config.rotary_dim
-        pos_embd_dim = self.rotary_dim or self.embed_dim
-        self.
+        self.pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.register_buffer(
+            "embed_positions", create_sinusoidal_positions(self.max_positions, self.pos_embd_dim), persistent=False
+        )

     def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
         """
@@ -334,8 +338,8 @@ class GPTJFlashAttention2(GPTJAttention):
                 else torch.get_autocast_gpu_dtype()
             )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "
-                target_dtype = self.config.
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype

@@ -444,6 +448,11 @@ class GPTJPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _can_compile_fullgraph = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, GPTJAttention):
+            init.copy_(module.embed_positions, create_sinusoidal_positions(module.max_positions, module.pos_embd_dim))
+

 @auto_docstring
 class GPTJModel(GPTJPreTrainedModel):
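The GPT-J change follows a pattern used throughout this release: position tables are registered as non-persistent buffers in `__init__` and refilled in `_init_weights`, so they never have to live in checkpoints. The sketch below illustrates that pattern with plain PyTorch; the body of `create_sinusoidal_positions` here is only an approximation of the existing GPT-J helper, not code from the release, and `TinyAttention` is a hypothetical stand-in class.

```python
import torch

def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
    # Approximate sketch of the GPT-J helper referenced in the diff: a (num_pos, dim)
    # table of sin/cos values used for rotary positions.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(torch.arange(num_pos, dtype=torch.float32), inv_freq)
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)

class TinyAttention(torch.nn.Module):
    def __init__(self, max_positions: int = 2048, pos_embd_dim: int = 64):
        super().__init__()
        # persistent=False keeps the table out of the checkpoint; it is cheap to rebuild.
        self.register_buffer(
            "embed_positions", create_sinusoidal_positions(max_positions, pos_embd_dim), persistent=False
        )

# Because the buffer is excluded from the state dict, loading a checkpoint leaves it
# untouched, and weight init refills it (here with plain copy_ instead of init.copy_).
attn = TinyAttention()
with torch.no_grad():
    attn.embed_positions.copy_(create_sinusoidal_positions(2048, 64))
```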
transformers/models/granite/modeling_granite.py

@@ -337,7 +337,7 @@ class GraniteRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
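Registering `original_inv_freq` as a non-persistent buffer (the same one-line change appears in the GraniteMoe and GraniteMoeHybrid rotary embeddings below) means the tensor follows `.to()`/`.cuda()` moves with the module while staying out of saved checkpoints. A minimal illustration, not the transformers class itself:

```python
import torch

class RotarySketch(torch.nn.Module):
    # Hypothetical module showing the buffer behaviour the diff relies on: buffers move
    # with the module across devices, and persistent=False keeps them out of state_dict,
    # whereas a plain attribute would neither move nor be tracked.
    def __init__(self, dim: int = 64, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

rot = RotarySketch()
print(list(rot.state_dict().keys()))  # [] -> neither buffer is serialized
```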
transformers/models/granite_speech/modeling_granite_speech.py

@@ -293,6 +293,12 @@ class GraniteSpeechPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, GraniteSpeechEncoderProjector):
             init.normal_(module.query)
+        elif isinstance(module, GraniteSpeechCTCEncoder):
+            context_size = module.config.context_size
+            seq = torch.arange(context_size)
+            relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
+            attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + module.config.max_pos_emb
+            init.copy_(module.attention_dists, attention_dists)


 @auto_docstring(
@@ -322,6 +328,12 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera

         self.post_init()

+    def set_decoder(self, decoder):
+        self.language_model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)

@@ -458,6 +470,7 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model
@@ -469,13 +482,14 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

         # If we're in cached decoding stage, input_features should be None because
         # input ids do not contain special audio token anymore Otherwise we need
         # input feature values to be passed to the model
-        if
+        if is_first_iteration or not kwargs.get("use_cache", True):
            model_inputs["input_features"] = input_features
         return model_inputs

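The GraniteSpeech override now receives an explicit `is_first_iteration` flag from generation instead of inferring the prefill step from cache contents, and audio features are only forwarded on that first step (or when no KV cache is used). A hedged, self-contained sketch of that control flow; `prepare_inputs_sketch` is a hypothetical stand-in, not the transformers method:

```python
def prepare_inputs_sketch(input_ids, input_features, is_first_iteration=False, **kwargs):
    # Stand-in for the GraniteSpeech prepare_inputs_for_generation override.
    model_inputs = {"input_ids": input_ids}
    # Audio features are only needed while the prompt still contains audio placeholder
    # tokens, i.e. on the first (prefill) step or when caching is disabled entirely.
    if is_first_iteration or not kwargs.get("use_cache", True):
        model_inputs["input_features"] = input_features
    return model_inputs

# Prefill step forwards the features; later cached decoding steps drop them.
print("input_features" in prepare_inputs_sketch([1, 2, 3], object(), is_first_iteration=True))   # True
print("input_features" in prepare_inputs_sketch([4], object(), is_first_iteration=False))        # False
```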
transformers/models/granitemoe/modeling_granitemoe.py

@@ -80,7 +80,7 @@ class GraniteMoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -456,8 +456,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": GraniteMoeDecoderLayer,
transformers/models/granitemoe/modular_granitemoe.py

@@ -146,8 +146,7 @@ class GraniteMoePreTrainedModel(LlamaPreTrainedModel, PreTrainedModel):
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"

     @torch.no_grad()
     def _init_weights(self, module):
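The updated `_can_compile_fullgraph` comment points at a concrete graph break: converting a per-expert count tensor to a Python list reads values on the host, which `torch.compile(fullgraph=True)` cannot trace. A minimal sketch of that failure mode, assuming a simplified TopK routing step (not the GraniteMoe gating code):

```python
import torch

def route_tokens(router_logits: torch.Tensor, top_k: int = 2):
    # Pick top-k experts per token, count tokens per expert, then read the counts on the
    # host. The .tolist() call is the data-dependent read the new comment refers to.
    top_experts = router_logits.topk(top_k, dim=-1).indices
    expert_size = torch.bincount(top_experts.flatten(), minlength=router_logits.shape[-1])
    return expert_size.tolist()  # graph break: values must leave the graph

# compiled = torch.compile(route_tokens, fullgraph=True)  # would error at the .tolist() call
print(route_tokens(torch.randn(4, 8)))
```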
transformers/models/granitemoehybrid/configuration_granitemoehybrid.py

@@ -92,6 +92,8 @@ class GraniteMoeHybridConfig(PreTrainedConfig):
             allow the model to output the auxiliary loss.
         router_aux_loss_coef (`float`, *optional*, defaults to 0.001): router auxiliary loss coefficient
         shared_intermediate_size (`int`, *optional*, defaults to 1024): intermediate size for shared experts.
+        position_embedding_type (`str`, *optional*):
+            Positional embedding type to be used; defaults to None. Allowed options: `[None, "rope"]`
         layer_types (`List`, *optional*): list of strings to be used as layer types.
             Allowed choices: "mamba", "attention".
         mamba_n_heads (`int`, *optional*, defaults to 128):
@@ -159,6 +161,7 @@ class GraniteMoeHybridConfig(PreTrainedConfig):
         output_router_logits: Optional[bool] = False,
         router_aux_loss_coef: Optional[float] = 0.001,
         shared_intermediate_size: Optional[int] = 1024,
+        position_embedding_type: Optional[str] = None,
         layer_types: Optional[list[str]] = None,
         mamba_n_heads: Optional[int] = 128,
         mamba_n_groups: Optional[int] = 1,
@@ -198,6 +201,7 @@ class GraniteMoeHybridConfig(PreTrainedConfig):
         self.output_router_logits = output_router_logits
         self.router_aux_loss_coef = router_aux_loss_coef
         self.shared_intermediate_size = shared_intermediate_size
+        self.position_embedding_type = position_embedding_type
         self.rope_parameters = rope_parameters

         mamba_intermediate = mamba_expand * hidden_size
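The new `position_embedding_type` flag is opt-in: it defaults to `None`, so existing GraniteMoeHybrid checkpoints keep their previous behaviour, and only `"rope"` turns on rotary embeddings in the modeling code below. A hedged usage sketch, assuming the config is exported at the top level and that default values for the remaining arguments are acceptable:

```python
from transformers import GraniteMoeHybridConfig

# Only the argument relevant to this change is shown; a real config would normally
# also set vocab/size/mamba parameters instead of relying on defaults.
config = GraniteMoeHybridConfig(position_embedding_type="rope")

# Per the modeling diff, the model then builds a rotary embedding only for "rope":
#     self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if config.position_embedding_type == "rope" else None
# and attention layers apply cos/sin only when position_embeddings is not None.
print(config.position_embedding_type)  # "rope"
```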
transformers/models/granitemoehybrid/modeling_granitemoehybrid.py

@@ -31,7 +31,12 @@ from transformers.activations import ACT2FN
 from ... import initialization as init
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
-from ...integrations import
+from ...integrations import (
+    lazy_load_kernel,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -40,22 +45,9 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.generic import check_model_inputs, maybe_autocast
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_granitemoehybrid import GraniteMoeHybridConfig


-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    selective_state_update = None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-
 logger = logging.get_logger(__name__)


@@ -165,6 +157,7 @@ class GraniteMoeHybridAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # None or rope embeddings
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
@@ -174,6 +167,10 @@ class GraniteMoeHybridAttention(nn.Module):
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
         if past_key_values is not None:
             cache_kwargs = {"cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -371,9 +368,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
     return hidden_states


-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
 class GraniteMoeHybridMambaLayer(nn.Module):
     """
@@ -445,6 +439,20 @@ class GraniteMoeHybridMambaLayer(nn.Module):

         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
+        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
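The import-time `mamba_ssm`/`causal_conv1d` probing is replaced by kernel resolution at the point where the first Mamba layer is built, with module-level globals updated so the existing fast-path check keeps working. The sketch below mirrors only the usage pattern visible in the diff; `lazy_load_kernel_sketch` is a hypothetical stand-in, since the diff shows just that `lazy_load_kernel(...)` returns an object (or `None`) whose functions are read via `getattr`.

```python
import importlib

def lazy_load_kernel_sketch(name: str):
    # Stand-in loader: try the installed package, otherwise return None so every
    # getattr below falls back to None and the fast path is reported unavailable.
    try:
        return importlib.import_module(name.replace("-", "_"))
    except ImportError:
        return None

causal_conv1d = lazy_load_kernel_sketch("causal-conv1d")
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)

mamba_ssm = lazy_load_kernel_sketch("mamba-ssm")
selective_state_update = getattr(mamba_ssm, "selective_state_update", None)

# Mirrors the flag the layer now sets: True only if every fast kernel resolved.
is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
print(is_fast_path_available)
```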
@@ -915,7 +923,7 @@ class GraniteMoeHybridRotaryEmbedding(nn.Module):
|
|
|
915
923
|
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
|
916
924
|
|
|
917
925
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
918
|
-
self.original_inv_freq =
|
|
926
|
+
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
|
919
927
|
|
|
920
928
|
@staticmethod
|
|
921
929
|
def compute_default_rope_parameters(
|
|
@@ -1231,8 +1239,7 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel):
|
|
|
1231
1239
|
_supports_flash_attn = True
|
|
1232
1240
|
_supports_sdpa = True
|
|
1233
1241
|
_supports_flex_attn = True
|
|
1234
|
-
|
|
1235
|
-
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
|
|
1242
|
+
_can_compile_fullgraph = False # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
|
|
1236
1243
|
_supports_attention_backend = True
|
|
1237
1244
|
_can_record_outputs = {
|
|
1238
1245
|
"hidden_states": GraniteMoeHybridDecoderLayer,
|
|
@@ -1265,7 +1272,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
|
|
|
1265
1272
|
[GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
|
1266
1273
|
)
|
|
1267
1274
|
self.norm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
1268
|
-
self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config
|
|
1275
|
+
self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if config.position_embedding_type == "rope" else None
|
|
1269
1276
|
self.gradient_checkpointing = False
|
|
1270
1277
|
self.embedding_multiplier = config.embedding_multiplier
|
|
1271
1278
|
|
|
@@ -1313,7 +1320,9 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
|
|
|
1313
1320
|
|
|
1314
1321
|
# embed positions
|
|
1315
1322
|
hidden_states = inputs_embeds
|
|
1316
|
-
position_embeddings =
|
|
1323
|
+
position_embeddings = None
|
|
1324
|
+
if self.rotary_emb is not None:
|
|
1325
|
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
|
1317
1326
|
|
|
1318
1327
|
for decoder_layer in self.layers:
|
|
1319
1328
|
# Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
|
|
@@ -1547,6 +1556,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMixin):
  cache_position=None,
  position_ids=None,
  use_cache=True,
+ is_first_iteration=False,
  **kwargs,
  ):
  # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
@@ -1579,7 +1589,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMixin):
  position_ids = position_ids[:, -input_ids.shape[1] :]

  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and
+ if inputs_embeds is not None and is_first_iteration:
      model_inputs = {"inputs_embeds": inputs_embeds}
  else:
      model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
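The new `is_first_iteration` argument makes the long-standing convention explicit: user-provided `inputs_embeds` can only seed the prompt, because later decoding steps only produce token ids. A hedged sketch of just that selection logic (the surrounding generation loop and the way the flag is plumbed through are simplified assumptions):

```python
import torch


def select_model_inputs(input_ids, inputs_embeds, is_first_iteration):
    # First step: forward the user-supplied embeddings directly.
    # Later steps: only freshly generated token ids exist, so fall back to them.
    if inputs_embeds is not None and is_first_iteration:
        return {"inputs_embeds": inputs_embeds}
    return {"input_ids": input_ids.contiguous()}


prompt_embeds = torch.randn(1, 4, 32)   # embeddings for a 4-token prompt
step_ids = torch.tensor([[42]])         # token produced by the previous step
print(select_model_inputs(step_ids, prompt_embeds, is_first_iteration=True).keys())   # inputs_embeds
print(select_model_inputs(step_ids, prompt_embeds, is_first_iteration=False).keys())  # input_ids
```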
@@ -39,6 +39,7 @@ from ..granitemoeshared.modeling_granitemoeshared import (
  GraniteMoeSharedModel,
  GraniteMoeSharedMoE,
  GraniteMoeSharedPreTrainedModel,
+ apply_rotary_pos_emb,
  eager_attention_forward,
  )
  from .configuration_granitemoehybrid import GraniteMoeHybridConfig
@@ -57,6 +58,7 @@ class GraniteMoeHybridAttention(GraniteMoeSharedAttention):
  attention_mask: Optional[torch.Tensor],
  past_key_values: Optional[Cache] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # None or rope embeddings
  **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[torch.Tensor, torch.Tensor]:
  input_shape = hidden_states.shape[:-1]
@@ -66,6 +68,10 @@ class GraniteMoeHybridAttention(GraniteMoeSharedAttention):
  key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

+ if position_embeddings is not None:
+     cos, sin = position_embeddings
+     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
  if past_key_values is not None:
      cache_kwargs = {"cache_position": cache_position}
      key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
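For reference, the `cos`/`sin` tuple produced by the rotary embedding is applied by rotating half of each head dimension, which is what `apply_rotary_pos_emb` does. A simplified sketch of that standard helper (the real implementation also exposes an `unsqueeze_dim` argument; the shapes below are demo assumptions):

```python
import torch


def rotate_half(x):
    # Swap the two halves of the last dimension and flip the sign of the second one.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_sketch(q, k, cos, sin):
    # cos/sin: (batch, seq_len, head_dim) -> broadcast over the head axis.
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


q = torch.randn(1, 2, 5, 8)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, 5, 8)
cos, sin = torch.ones(1, 5, 8), torch.zeros(1, 5, 8)
q_rot, _ = apply_rotary_pos_emb_sketch(q, k, cos, sin)
print(torch.allclose(q_rot, q))  # True: sin == 0 leaves the states unchanged
```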
@@ -203,6 +209,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
  [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  )
  self.embedding_multiplier = config.embedding_multiplier
+ self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if config.position_embedding_type == "rope" else None

  @auto_docstring
  @check_model_inputs
@@ -245,7 +252,9 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):

  # embed positions
  hidden_states = inputs_embeds
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ position_embeddings = None
+ if self.rotary_emb is not None:
+     position_embeddings = self.rotary_emb(hidden_states, position_ids)

  for decoder_layer in self.layers:
      # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
@@ -300,6 +309,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
  cache_position=None,
  position_ids=None,
  use_cache=True,
+ is_first_iteration=False,
  **kwargs,
  ):
  # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
@@ -332,7 +342,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
  position_ids = position_ids[:, -input_ids.shape[1] :]

  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and
+ if inputs_embeds is not None and is_first_iteration:
      model_inputs = {"inputs_embeds": inputs_embeds}
  else:
      model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
@@ -462,8 +462,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel):
  _supports_flash_attn = True
  _supports_sdpa = True
  _supports_flex_attn = True
-
- _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
  _supports_attention_backend = True
  _can_record_outputs = {
      "hidden_states": GraniteMoeSharedDecoderLayer,
@@ -494,7 +493,7 @@ class GraniteMoeSharedRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  def compute_default_rope_parameters(
@@ -34,7 +34,7 @@ class GroundingDinoConfig(PreTrainedConfig):
  documentation from [`PreTrainedConfig`] for more information.

  Args:
- backbone_config (`PreTrainedConfig` or `dict`, *optional*, defaults to `SwinConfig()`):
+ backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `SwinConfig()`):
      The configuration of the backbone model.
  backbone (`str`, *optional*):
      Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -285,9 +285,8 @@ class GroundingDinoConfig(PreTrainedConfig):
  self.positional_embedding_temperature = positional_embedding_temperature
  self.init_std = init_std
  self.layer_norm_eps = layer_norm_eps
+
  super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
- self.tie_encoder_decoder = True
- self.tie_encoder_decoder = True


  __all__ = ["GroundingDinoConfig"]
@@ -1415,7 +1415,7 @@ class GroundingDinoPreTrainedModel(PreTrainedModel):
  elif isinstance(module, GroundingDinoFusionLayer):
      init.constant_(module.vision_param, 1e-4)
      init.constant_(module.text_param, 1e-4)
- elif isinstance(module, (nn.Linear, nn.Conv2d
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
      init.normal_(module.weight, mean=0.0, std=std)
      if module.bias is not None:
          init.zeros_(module.bias)
@@ -1511,7 +1511,7 @@ class GroundingDinoEncoder(GroundingDinoPreTrainedModel):
  output_hidden_states=None,
  return_dict=None,
  **kwargs,
- ):
+ ) -> Union[tuple, GroundingDinoEncoderOutput]:
  r"""
  Args:
  vision_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -1666,7 +1666,7 @@ class GroundingDinoDecoder(GroundingDinoPreTrainedModel):
  output_hidden_states=None,
  return_dict=None,
  **kwargs,
- ):
+ ) -> Union[tuple, GroundingDinoDecoderOutput]:
  r"""
  Args:
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
@@ -2059,7 +2059,7 @@ class GroundingDinoModel(GroundingDinoPreTrainedModel):
  output_hidden_states=None,
  return_dict=None,
  **kwargs,
- ):
+ ) -> Union[tuple, GroundingDinoModelOutput]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`):
      Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -758,14 +758,19 @@ class GroupViTPreTrainedModel(PreTrainedModel):
  init.normal_(module.weight, mean=0.0, std=init_range)
  if module.bias is not None:
      init.zeros_(module.bias)
- elif isinstance(module, nn.LayerNorm):
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
      init.zeros_(module.bias)
      init.ones_(module.weight)
+     if getattr(module, "running_mean", None) is not None:
+         init.zeros_(module.running_mean)
+         init.ones_(module.running_var)
+         init.zeros_(module.num_batches_tracked)

  factor = self.config.initializer_factor
  if isinstance(module, GroupViTTextEmbeddings):
      init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
      init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
+     init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  elif isinstance(module, GroupViTAttention):
      factor = self.config.initializer_factor
      in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
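Both this hunk and the Hubert/HGNetV2 hunks further down reset the BatchNorm running statistics, not only the affine weight and bias. PyTorch offers the same buffer reset through `reset_running_stats()`; a small sketch showing the two routes side by side (plain `torch.nn.init` stands in here for the library's `initialization` helpers):

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(4)
bn(torch.randn(8, 4))  # pretend the module has already seen data

# Manual reset, mirroring the diff: unit weight, zero bias, fresh running stats.
nn.init.ones_(bn.weight)
nn.init.zeros_(bn.bias)
with torch.no_grad():
    bn.running_mean.zero_()
    bn.running_var.fill_(1.0)
    bn.num_batches_tracked.zero_()

# Built-in equivalent for the three buffers only:
bn.reset_running_stats()
print(bn.running_mean, bn.running_var, bn.num_batches_tracked)
```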
@@ -79,7 +79,7 @@ class HeliumRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  def compute_default_rope_parameters(
@@ -26,6 +26,7 @@ import torch
  import torch.nn.functional as F
  from torch import Tensor, nn

+ from ... import initialization as init
  from ...activations import ACT2FN
  from ...modeling_outputs import BackboneOutput, BaseModelOutputWithNoAttention, ImageClassifierOutputWithNoAttention
  from ...modeling_utils import PreTrainedModel
@@ -45,6 +46,15 @@ class HGNetV2PreTrainedModel(PreTrainedModel):
  input_modalities = ("image",)
  _no_split_modules = ["HGNetV2BasicLayer"]

+ def _init_weights(self, module):
+     super()._init_weights(module)
+     # We need to check it like that as d_fine models replace the BatchNorm2d by their own
+     if "BatchNorm" in module.__class__.__name__:
+         init.ones_(module.weight)
+         init.zeros_(module.bias)
+         init.zeros_(module.running_mean)
+         init.ones_(module.running_var)
+

  class HGNetV2LearnableAffineBlock(nn.Module):
      def __init__(self, scale_value: float = 1.0, bias_value: float = 0.0):
@@ -20,6 +20,7 @@ import torch
  import torch.nn.functional as F
  from torch import Tensor, nn

+ from ... import initialization as init
  from ...configuration_utils import PreTrainedConfig
  from ...modeling_outputs import (
      BackboneOutput,
@@ -170,6 +171,15 @@ class HGNetV2PreTrainedModel(PreTrainedModel):
  input_modalities = ("image",)
  _no_split_modules = ["HGNetV2BasicLayer"]

+ def _init_weights(self, module):
+     super()._init_weights(module)
+     # We need to check it like that as d_fine models replace the BatchNorm2d by their own
+     if "BatchNorm" in module.__class__.__name__:
+         init.ones_(module.weight)
+         init.zeros_(module.bias)
+         init.zeros_(module.running_mean)
+         init.ones_(module.running_var)
+

  class HGNetV2LearnableAffineBlock(nn.Module):
      def __init__(self, scale_value: float = 1.0, bias_value: float = 0.0):
@@ -648,6 +648,10 @@ class HubertPreTrainedModel(PreTrainedModel):
  elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
      init.zeros_(module.bias)
      init.ones_(module.weight)
+     if getattr(module, "running_mean", None) is not None:
+         init.zeros_(module.running_mean)
+         init.ones_(module.running_var)
+         init.zeros_(module.num_batches_tracked)
  elif isinstance(module, nn.Conv1d):
      if is_deepspeed_zero3_enabled():
          import deepspeed
@@ -145,6 +145,10 @@ class HubertPreTrainedModel(PreTrainedModel):
  elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
      init.zeros_(module.bias)
      init.ones_(module.weight)
+     if getattr(module, "running_mean", None) is not None:
+         init.zeros_(module.running_mean)
+         init.ones_(module.running_var)
+         init.zeros_(module.num_batches_tracked)
  elif isinstance(module, nn.Conv1d):
      if is_deepspeed_zero3_enabled():
          import deepspeed
@@ -320,7 +320,7 @@ class HunYuanDenseV1RotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  def compute_default_rope_parameters(
@@ -148,7 +148,7 @@ class HunYuanDenseV1RotaryEmbedding(LlamaRotaryEmbedding):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)


  class HunYuanDenseV1Model(LlamaModel):
@@ -30,14 +30,19 @@ from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations import
+ from ...integrations import (
+     use_experts_implementation,
+     use_kernel_forward_from_hub,
+     use_kernel_func_from_hub,
+     use_kernelized_func,
+ )
  from ...masking_utils import create_causal_mask
  from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
  from ...utils.generic import check_model_inputs, maybe_autocast
  from .configuration_hunyuan_v1_moe import HunYuanMoEV1Config

@@ -244,6 +249,7 @@ class HunYuanMoEV1Gate(nn.Module):
  return logits


+ @use_experts_implementation
  class HunYuanMoEV1Experts(nn.Module):
      """Collection of expert weights stored as 3D tensors."""

@@ -371,7 +377,9 @@ class HunYuanMoEV1PreTrainedModel(PreTrainedModel):
  _supports_flash_attn = True
  _supports_sdpa = True
  _supports_flex_attn = True
- _can_compile_fullgraph =
+ _can_compile_fullgraph = (
+     is_grouped_mm_available()
+ )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
  _supports_attention_backend = True
  _can_record_outputs = {
      "hidden_states": HunYuanMoEV1DecoderLayer,
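`_can_compile_fullgraph` is now tied to `is_grouped_mm_available()`, so the MoE experts only opt into fullgraph `torch.compile` when the grouped-matmul path the experts interface relies on is present. A hedged usage sketch, assuming only that the helper is importable from `transformers.utils` as the import hunk above indicates:

```python
import torch
from transformers.utils import is_grouped_mm_available


def maybe_compile(model: torch.nn.Module) -> torch.nn.Module:
    # Request fullgraph capture only when grouped matmul is supported;
    # otherwise allow graph breaks rather than failing compilation.
    return torch.compile(model, fullgraph=is_grouped_mm_available())
```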
@@ -413,7 +421,7 @@ class HunYuanMoEV1RotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  def compute_default_rope_parameters(
|