transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0

@@ -22,6 +22,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
@@ -187,6 +188,13 @@ class CTRLPreTrainedModel(PreTrainedModel):
     config: CTRLConfig
     base_model_prefix = "transformer"
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, CTRLModel):
+            init.copy_(
+                module.pos_encoding, positional_encoding(module.config.n_positions, module.d_model_size, torch.float)
+            )
+
 
 @auto_docstring
 class CTRLModel(CTRLPreTrainedModel):
@@ -196,7 +204,9 @@ class CTRLModel(CTRLPreTrainedModel):
         self.d_model_size = config.n_embd
         self.num_layers = config.n_layer
 
-        self.
+        self.register_buffer(
+            "pos_encoding", positional_encoding(config.n_positions, self.d_model_size, torch.float), persistent=False
+        )
 
         self.w = nn.Embedding(config.vocab_size, config.n_embd)
 
@@ -470,7 +480,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel, GenerationMixin):
             attentions=transformer_outputs.attentions,
         )
 
-    def prepare_inputs_for_generation(
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, use_cache=None, is_first_iteration=False, **kwargs
+    ):
        # Overwritten -- inputs_embeds not working properly
 
        # only last tokens for inputs_ids if past is defined in kwargs
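
The hunks above (judging by the file list, transformers/models/ctrl/modeling_ctrl.py) move the sinusoidal position table into a non-persistent buffer and re-fill it from `_init_weights` via the new `transformers.initialization` helpers. A minimal sketch of the same pattern outside transformers — `positional_encoding` and `ToyModel` are hypothetical names, not symbols from the wheel:

```python
import math

import torch
from torch import nn


def positional_encoding(num_positions: int, dim: int) -> torch.Tensor:
    # standard sinusoidal table of shape (num_positions, dim); dim is assumed even
    position = torch.arange(num_positions, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim))
    table = torch.zeros(num_positions, dim)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table


class ToyModel(nn.Module):
    def __init__(self, num_positions: int = 512, dim: int = 64):
        super().__init__()
        # persistent=False keeps the table out of the state_dict: checkpoints stay small and
        # loading never depends on a stored copy of something recomputable from the config
        self.register_buffer("pos_encoding", positional_encoding(num_positions, dim), persistent=False)

    @torch.no_grad()
    def _init_weights(self):
        # refill the buffer in place, roughly what an init.copy_-style helper does
        self.pos_encoding.copy_(positional_encoding(*self.pos_encoding.shape))


model = ToyModel()
assert "pos_encoding" not in model.state_dict()  # non-persistent buffers are not serialized
```

Because the buffer never lands in the checkpoint, it has to be reconstructible at load time, which is exactly what the added `_init_weights` override provides.
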
@@ -497,9 +497,13 @@ class CvtPreTrainedModel(PreTrainedModel):
             init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
-        elif isinstance(module, nn.LayerNorm):
+        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, CvtStage):
             if self.config.cls_token[module.stage]:
                 init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
@@ -58,7 +58,7 @@ class CwmRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
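
This one-line change (and the identical one further down for `DbrxRotaryEmbedding`) registers `original_inv_freq` as a non-persistent buffer instead of a plain attribute. A small sketch of why that matters, using a hypothetical `RotaryFreqs` module: buffers follow `.to()` / dtype casts together with the module, plain attributes do not.

```python
import torch
from torch import nn


class RotaryFreqs(nn.Module):
    def __init__(self, dim: int = 64, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # a buffer (unlike a plain attribute) is moved by .to()/.cuda()/.half() together
        # with the rest of the module, and keeps a pristine copy for later re-scaling
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)


freqs = RotaryFreqs().to(torch.float64)
print(freqs.original_inv_freq.dtype)  # torch.float64 -- a plain attribute would still be float32
```
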
@@ -47,7 +47,7 @@ class DFineConfig(PreTrainedConfig):
            The epsilon used by the layer normalization layers.
        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the batch normalization layers.
-        backbone_config (`
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`):
            The configuration of the backbone model.
        backbone (`str`, *optional*):
            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -288,8 +288,7 @@ class DFineConfig(PreTrainedConfig):
             )
             backbone_model_type = "hgnet_v2"
             config_class = CONFIG_MAPPING[backbone_model_type]
-            # this will map it to
-            # note: we can instead create HGNetV2Config
+            # this will map it to HGNetV2Config
             # and we would need to create HGNetV2Backbone
             backbone_config = config_class(
                 num_channels=3,
@@ -395,8 +394,8 @@ class DFineConfig(PreTrainedConfig):
             raise ValueError(
                 f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
             )
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
 
 
 __all__ = ["DFineConfig"]
@@ -483,6 +483,9 @@ class DFinePreTrainedModel(PreTrainedModel):
             init.constant_(module.attention_weights.weight, 0.0)
             init.constant_(module.attention_weights.bias, 0.0)
 
+            num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)]
+            init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32))
+
         if isinstance(module, DFineModel):
             prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
             bias = float(-math.log((1 - prior_prob) / prior_prob))
@@ -493,6 +496,10 @@ class DFinePreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
 
         if isinstance(module, DFineGate):
             bias = float(-math.log((1 - 0.5) / 0.5))
@@ -838,6 +845,45 @@ class DFineDecoder(DFinePreTrainedModel):
         )
 
 
+class DFineFrozenBatchNorm2d(nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+    Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
+    torchvision.models.resnet[18,34,50,101] produce nans.
+    """
+
+    def __init__(self, n):
+        super().__init__()
+        self.register_buffer("weight", torch.ones(n))
+        self.register_buffer("bias", torch.zeros(n))
+        self.register_buffer("running_mean", torch.zeros(n))
+        self.register_buffer("running_var", torch.ones(n))
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        num_batches_tracked_key = prefix + "num_batches_tracked"
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def forward(self, x):
+        # move reshapes to the beginning
+        # to make it user-friendly
+        weight = self.weight.reshape(1, -1, 1, 1)
+        bias = self.bias.reshape(1, -1, 1, 1)
+        running_var = self.running_var.reshape(1, -1, 1, 1)
+        running_mean = self.running_mean.reshape(1, -1, 1, 1)
+        epsilon = 1e-5
+        scale = weight * (running_var + epsilon).rsqrt()
+        bias = bias - running_mean * scale
+        return x * scale + bias
+
+
 @dataclass
 @auto_docstring(
     custom_intro="""
@@ -896,45 +942,6 @@ class DFineModelOutput(ModelOutput):
     denoising_meta_values: Optional[dict] = None
 
 
-class DFineFrozenBatchNorm2d(nn.Module):
-    """
-    BatchNorm2d where the batch statistics and the affine parameters are fixed.
-
-    Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
-    torchvision.models.resnet[18,34,50,101] produce nans.
-    """
-
-    def __init__(self, n):
-        super().__init__()
-        self.register_buffer("weight", torch.ones(n))
-        self.register_buffer("bias", torch.zeros(n))
-        self.register_buffer("running_mean", torch.zeros(n))
-        self.register_buffer("running_var", torch.ones(n))
-
-    def _load_from_state_dict(
-        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-    ):
-        num_batches_tracked_key = prefix + "num_batches_tracked"
-        if num_batches_tracked_key in state_dict:
-            del state_dict[num_batches_tracked_key]
-
-        super()._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-        )
-
-    def forward(self, x):
-        # move reshapes to the beginning
-        # to make it user-friendly
-        weight = self.weight.reshape(1, -1, 1, 1)
-        bias = self.bias.reshape(1, -1, 1, 1)
-        running_var = self.running_var.reshape(1, -1, 1, 1)
-        running_mean = self.running_mean.reshape(1, -1, 1, 1)
-        epsilon = 1e-5
-        scale = weight * (running_var + epsilon).rsqrt()
-        bias = bias - running_mean * scale
-        return x * scale + bias
-
-
 def replace_batch_norm(model):
     r"""
     Recursively replace all `torch.nn.BatchNorm2d` with `DFineFrozenBatchNorm2d`.
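
This group (per the file list, transformers/models/d_fine/modeling_d_fine.py) relocates `DFineFrozenBatchNorm2d` ahead of its first use, deletes the identical definition from its old spot, and seeds the BatchNorm running statistics and the `num_points_scale` buffer in `_init_weights`. The frozen variant simply folds fixed running statistics into a per-channel affine transform; a small check, assuming the default eps of 1e-5 used in the hunk above, that this matches `nn.BatchNorm2d` in eval mode:

```python
import torch
from torch import nn

# A frozen BatchNorm computes y = x * scale + shift, where
#   scale = weight / sqrt(running_var + eps)
#   shift = bias - running_mean * scale
# i.e. exactly what nn.BatchNorm2d does in eval mode with fixed running stats.
bn = nn.BatchNorm2d(8).eval()  # default eps is 1e-5
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(2, 8, 4, 4)
scale = (bn.weight / torch.sqrt(bn.running_var + bn.eps)).reshape(1, -1, 1, 1)
shift = bn.bias.reshape(1, -1, 1, 1) - bn.running_mean.reshape(1, -1, 1, 1) * scale
torch.testing.assert_close(x * scale + shift, bn(x))
```
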
@@ -33,6 +33,7 @@ from ..rt_detr.modeling_rt_detr import (
     RTDetrDecoderOutput,
     RTDetrEncoder,
     RTDetrForObjectDetection,
+    RTDetrFrozenBatchNorm2d,
     RTDetrHybridEncoder,
     RTDetrMLPPredictionHead,
     RTDetrModel,
@@ -66,7 +67,7 @@ class DFineConfig(PreTrainedConfig):
            The epsilon used by the layer normalization layers.
        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the batch normalization layers.
-    backbone_config (`
+    backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`):
            The configuration of the backbone model.
        backbone (`str`, *optional*):
            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -307,8 +308,7 @@ class DFineConfig(PreTrainedConfig):
             )
             backbone_model_type = "hgnet_v2"
             config_class = CONFIG_MAPPING[backbone_model_type]
-            # this will map it to
-            # note: we can instead create HGNetV2Config
+            # this will map it to HGNetV2Config
             # and we would need to create HGNetV2Backbone
             backbone_config = config_class(
                 num_channels=3,
@@ -414,8 +414,8 @@ class DFineConfig(PreTrainedConfig):
             raise ValueError(
                 f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
             )
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
 
 
 class DFineMultiscaleDeformableAttention(nn.Module):
@@ -628,6 +628,9 @@ class DFinePreTrainedModel(RTDetrPreTrainedModel):
             init.constant_(module.attention_weights.weight, 0.0)
             init.constant_(module.attention_weights.bias, 0.0)
 
+            num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)]
+            init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32))
+
         if isinstance(module, DFineModel):
             prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
             bias = float(-math.log((1 - prior_prob) / prior_prob))
@@ -638,6 +641,10 @@ class DFinePreTrainedModel(RTDetrPreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
 
         if isinstance(module, DFineGate):
             bias = float(-math.log((1 - 0.5) / 0.5))
@@ -851,6 +858,10 @@ class DFineDecoder(RTDetrDecoder):
         )
 
 
+class DFineFrozenBatchNorm2d(RTDetrFrozenBatchNorm2d):
+    pass
+
+
 class DFineModel(RTDetrModel):
     def __init__(self, config: DFineConfig):
         super().__init__(config)
@@ -37,7 +37,7 @@ class DabDetrConfig(PreTrainedConfig):
        use_timm_backbone (`bool`, *optional*, defaults to `True`):
            Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
            API.
-        backbone_config (`PreTrainedConfig
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `ResNetConfig()`):
            The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
            case it will default to `ResNetConfig()`.
        backbone (`str`, *optional*, defaults to `"resnet50"`):
@@ -255,8 +255,8 @@ class DabDetrConfig(PreTrainedConfig):
         self.temperature_height = temperature_height
         self.sine_position_embedding_scale = sine_position_embedding_scale
         self.initializer_bias_prior_prob = initializer_bias_prior_prob
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True  # weights have to be tied for this model
 
 
 __all__ = ["DabDetrConfig"]
@@ -826,7 +826,7 @@ class DabDetrPreTrainedModel(PreTrainedModel):
             init.zeros_(module.q_linear.bias)
             init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
             init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
-        if isinstance(module, (nn.Linear, nn.Conv2d
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
             init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
                 init.zeros_(module.bias)
@@ -16,7 +16,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import numpy as np
 import torch
@@ -583,7 +583,7 @@ class DacModel(DacPreTrainedModel):
         input_values: torch.Tensor,
         n_quantizers: Optional[int] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[tuple, DacEncoderOutput]:
         r"""
         input_values (`torch.Tensor of shape `(batch_size, 1, time_steps)`):
             Input audio data to encode,
@@ -610,7 +610,7 @@ class DacModel(DacPreTrainedModel):
         quantized_representation: Optional[torch.Tensor] = None,
         audio_codes: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[tuple, DacDecoderOutput]:
         r"""
         quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`, *optional*):
             Quantized continuous representation of input.
@@ -643,7 +643,7 @@ class DacModel(DacPreTrainedModel):
         input_values: torch.Tensor,
         n_quantizers: Optional[int] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[tuple, DacOutput]:
         r"""
         input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
             Audio data to encode.
@@ -26,6 +26,7 @@ import torch
 import torch.nn as nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ... import initialization as init
 from ...activations import ACT2FN, gelu
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -494,6 +495,12 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
         "cross_attentions": Data2VecTextCrossAttention,
     }
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, Data2VecTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
+
 
 class Data2VecTextEncoder(nn.Module):
     def __init__(self, config):
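Note: the `_init_weights` override added here (and mirrored in the modular file, the Deberta/DebertaV2 hunks, and the DecisionTransformer hunk below) follows one pattern: defer to the parent initializer, then fill non-persistent buffers such as `position_ids` via the new `transformers.initialization` helpers. A minimal self-contained sketch of that pattern, using toy modules rather than the actual transformers classes (`ToyEmbeddings`/`ToyModel` are illustrative; `init.copy_`/`init.zeros_` are emulated with in-place tensor ops):

```python
# Sketch of the buffer-initialization pattern used by the new _init_weights overrides.
import torch
import torch.nn as nn


class ToyEmbeddings(nn.Module):
    def __init__(self, max_positions: int = 8):
        super().__init__()
        # non-persistent buffers: not saved in checkpoints, so they must be (re)filled at init time
        self.register_buffer("position_ids", torch.empty(1, max_positions, dtype=torch.long), persistent=False)
        self.register_buffer("token_type_ids", torch.empty(1, max_positions, dtype=torch.long), persistent=False)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = ToyEmbeddings()

    @torch.no_grad()
    def _init_weights(self, module):
        # the parent initializer would handle Linear/Embedding weights; only the buffer part is shown
        if isinstance(module, ToyEmbeddings):
            module.position_ids.copy_(torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
            module.token_type_ids.zero_()


model = ToyModel()
model.apply(model._init_weights)
print(model.embeddings.position_ids)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
```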
@@ -20,6 +20,7 @@ import torch
 import torch.nn as nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ... import initialization as init
 from ...generation import GenerationMixin
 from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -81,6 +82,12 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
         "cross_attentions": Data2VecTextCrossAttention,
     }
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, Data2VecTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
+
 
 @auto_docstring
 class Data2VecTextModel(RobertaModel):
@@ -104,7 +104,15 @@ class DbrxFFNConfig(PreTrainedConfig):
         self.moe_loss_weight = moe_loss_weight
         self.moe_normalize_expert_weights = moe_normalize_expert_weights
 
-        for k in [
+        for k in [
+            "model_type",
+            "attn_implementation",
+            "experts_implementation",
+            "transformers_version",
+            "_commit_hash",
+            "torch_dtype",
+            "dtype",
+        ]:
             if k in kwargs:
                 kwargs.pop(k)
         if len(kwargs) != 0:
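The DbrxFFNConfig hunk reflows the list of reserved keys onto multiple lines and extends it; the surrounding logic remains the usual "strip known framework keys, then reject whatever is left" guard. A small sketch of that guard with illustrative names (`RESERVED_KEYS` and `ToyFFNConfig` are not transformers symbols):

```python
# Sketch of the "drop bookkeeping kwargs, then fail on leftovers" guard.
RESERVED_KEYS = ["model_type", "transformers_version", "_commit_hash", "torch_dtype", "dtype"]


class ToyFFNConfig:
    def __init__(self, ffn_hidden_size: int = 1024, **kwargs):
        self.ffn_hidden_size = ffn_hidden_size
        for k in RESERVED_KEYS:
            if k in kwargs:
                kwargs.pop(k)  # silently drop keys the serialization/hub layer adds
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown kwargs: {kwargs}")


ToyFFNConfig(ffn_hidden_size=2048, torch_dtype="bfloat16")  # fine: reserved key is dropped
```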
@@ -58,7 +58,7 @@ class DbrxRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
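This hunk (and the matching DeepseekV2/V3 rotary-embedding hunks further down) replaces a plain attribute with `register_buffer(..., persistent=False)`, so the saved copy of `inv_freq` follows the module across device moves while staying out of `state_dict`. A minimal sketch of the behavior, with a toy rotary module standing in for the real classes:

```python
# Sketch of the register_buffer change: a non-persistent buffer moves with the
# module on .to(device) and is excluded from state_dict, unlike a plain tensor attribute.
import torch
import torch.nn as nn


class ToyRotary(nn.Module):
    def __init__(self, dim: int = 8, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # keep an untouched copy (e.g. for dynamic-RoPE resets), also non-persistent
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)


rope = ToyRotary()
print("saved in state_dict:", list(rope.state_dict().keys()))  # [] -> neither buffer is checkpointed
print(rope.original_inv_freq.shape)
```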
@@ -624,6 +624,8 @@ class DebertaPreTrainedModel(PreTrainedModel):
             init.zeros_(module.v_bias)
         elif isinstance(module, (LegacyDebertaLMPredictionHead, DebertaLMPredictionHead)):
             init.zeros_(module.bias)
+        elif isinstance(module, DebertaEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
 @auto_docstring
@@ -700,6 +700,8 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, (LegacyDebertaV2LMPredictionHead, DebertaV2LMPredictionHead)):
             init.zeros_(module.bias)
+        elif isinstance(module, DebertaV2Embeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
 @auto_docstring
@@ -94,7 +94,6 @@ class DecisionTransformerGPT2Attention(nn.Module):
             ),
             persistent=False,
         )
-        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
 
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -367,12 +366,8 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
     config: DecisionTransformerConfig
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
-
     _can_compile_fullgraph = False
 
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
     @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -389,6 +384,14 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
                 if "c_proj" in name and "weight" in name:
                     # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                     init.normal_(p, mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
+        elif isinstance(module, DecisionTransformerGPT2Attention):
+            max_positions = module.config.max_position_embeddings
+            init.copy_(
+                module.bias,
+                torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+                    1, 1, max_positions, max_positions
+                ),
+            )
 
 
 class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
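The new `DecisionTransformerGPT2Attention` branch re-materializes the causal-mask buffer during weight initialization; the mask itself is an ordinary lower-triangular boolean matrix broadcast to `(1, 1, max_positions, max_positions)`. A tiny standalone sketch of that construction:

```python
# Sketch of the causal mask filled into the attention bias buffer above.
import torch

max_positions = 6
causal_mask = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)
print(causal_mask[0, 0].int())  # ones on and below the diagonal, zeros above
```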
@@ -30,18 +30,19 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_experts_implementation, use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_deepseek_v2 import DeepseekV2Config
 
 
+@use_experts_implementation
 class DeepseekV2Experts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
@@ -184,7 +185,7 @@ class DeepseekV2RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -453,7 +454,9 @@ class DeepseekV2PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": DeepseekV2DecoderLayer,
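Several MoE models in this diff now derive `_can_compile_fullgraph` from `is_grouped_mm_available()` at class-definition time, so full-graph `torch.compile` support is only advertised when grouped matrix multiplication is present. A sketch of that capability-gated class attribute; `has_grouped_mm` is an illustrative stand-in, not the actual transformers helper:

```python
# Sketch of a capability-gated class attribute, as in the _can_compile_fullgraph hunks.
import torch


def has_grouped_mm() -> bool:
    # illustrative probe only: the real availability check depends on the torch build
    return hasattr(torch, "_grouped_mm")


class ToyMoePreTrainedModel:
    # evaluated once, when the class object is created
    _can_compile_fullgraph = has_grouped_mm()


print(ToyMoePreTrainedModel._can_compile_fullgraph)
```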
@@ -24,7 +24,7 @@ from ... import initialization as init
 from ...cache_utils import Cache
 from ...modeling_rope_utils import RopeParameters, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...utils import logging
+from ...utils import is_grouped_mm_available, logging
 from ...utils.generic import maybe_autocast
 from ..llama.configuration_llama import LlamaConfig
 from ..llama.modeling_llama import (
@@ -437,7 +437,9 @@ class DeepseekV2DecoderLayer(LlamaDecoderLayer):
 
 
 class DeepseekV2PreTrainedModel(LlamaPreTrainedModel):
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
 
     @torch.no_grad()
     def _init_weights(self, module):
@@ -16,7 +16,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernel_func_from_hub
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -28,7 +28,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_deepseek_v3 import DeepseekV3Config
 
@@ -71,7 +71,7 @@ class DeepseekV3RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -150,6 +150,7 @@ class DeepseekV3TopkRouter(nn.Module):
         return router_logits
 
 
+@use_experts_implementation
 class DeepseekV3NaiveMoe(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
@@ -157,7 +158,7 @@ class DeepseekV3NaiveMoe(nn.Module):
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_dim = config.hidden_size
-        self.intermediate_dim = config.
+        self.intermediate_dim = config.moe_intermediate_size
         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
         self.act_fn = ACT2FN[config.hidden_act]
@@ -542,7 +543,9 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": DeepseekV3DecoderLayer,
@@ -555,6 +558,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, DeepseekV3TopkRouter):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            init.zeros_(module.e_score_correction_bias)
         elif isinstance(module, DeepseekV3NaiveMoe):
             init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
@@ -12,7 +12,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GenericForSequenceClassification, GenericForTokenClassification
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import logging
+from ...utils import is_grouped_mm_available, logging
 from ..llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaForCausalLM,
@@ -107,6 +107,7 @@ class DeepseekV3NaiveMoe(MixtralExperts):
     def __init__(self, config):
         super().__init__(config)
         self.num_experts = config.num_local_experts
+        self.intermediate_dim = config.moe_intermediate_size
 
 
 class DeepseekV3MoE(nn.Module):
@@ -303,7 +304,9 @@ class DeepseekV3DecoderLayer(LlamaDecoderLayer):
 
 
 class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
 
     @torch.no_grad()
@@ -311,6 +314,7 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
         PreTrainedModel._init_weights(self, module)
         if isinstance(module, DeepseekV3TopkRouter):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            init.zeros_(module.e_score_correction_bias)
         elif isinstance(module, DeepseekV3NaiveMoe):
             init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
@@ -171,7 +171,6 @@ class DeepseekVLImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 