transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -91,6 +91,7 @@ class EvollaSaProtRotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int):
         super().__init__()
+        self.dim = dim
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
@@ -203,12 +204,19 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel):
         ],
     }
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, EvollaSaProtRotaryEmbedding):
+            inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
 
 class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
     def __init__(self, config: SaProtConfig):
         super().__init__(config)
         self.embeddings = EvollaSaProtEmbeddings(config)
         self.encoder = EvollaSaProtEncoder(config)
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
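The Evolla hunks above show a pattern that recurs throughout this release: computed buffers are re-derived inside `_init_weights` from stored module attributes (here `self.dim`) instead of being taken on trust from construction time. A minimal, self-contained sketch of that idea in plain PyTorch; the `Rotary` class and `init_weights` helper below are illustrative stand-ins, not the transformers `initialization` API:

import torch
from torch import nn


class Rotary(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim  # kept so the buffer can be recomputed later
        # Placeholder storage; real values are filled in by init_weights below.
        self.register_buffer("inv_freq", torch.empty(dim // 2))


def init_weights(module: nn.Module) -> None:
    # Recompute derived buffers from module attributes instead of loading them.
    if isinstance(module, Rotary):
        inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2).float() / module.dim))
        with torch.no_grad():
            module.inv_freq.copy_(inv_freq)


model = Rotary(dim=8)
model.apply(init_weights)
print(model.inv_freq.shape)  # torch.Size([4])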
@@ -86,7 +86,7 @@ class Exaone4RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -122,7 +122,7 @@ class FalconRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
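The same one-line change appears across the RoPE classes in this release: the `original_inv_freq` copy moves from a plain attribute to a cloned, non-persistent registered buffer. A small sketch of why that matters, using a toy module rather than the actual transformers classes:

import torch
from torch import nn


class RopeSketch(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # Buffers track .to(device)/.to(dtype); persistent=False keeps them out of the
        # state_dict because they can always be recomputed from the config.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # .clone() gives the saved copy its own storage, so later in-place edits to
        # inv_freq (e.g. dynamic RoPE rescaling) cannot silently alter it.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)


m = RopeSketch().to(torch.float16)
print(m.original_inv_freq.dtype)              # torch.float16 (buffers follow the module cast)
print("original_inv_freq" in m.state_dict())  # False (non-persistent buffer)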
@@ -521,8 +521,8 @@ class FalconFlashAttention2(FalconAttention):
                 else torch.get_autocast_gpu_dtype()
             )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.query_key_value.weight.dtype
 
@@ -241,7 +241,7 @@ class FalconH1RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -1187,26 +1187,6 @@ class FalconH1DecoderLayer(GradientCheckpointingLayer):
         return outputs
 
 
-@auto_docstring
-class FalconH1PreTrainedModel(PreTrainedModel):
-    config: FalconH1Config
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["FalconH1DecoderLayer"]
-    _skip_keys_device_placement = "past_key_values"
-    _supports_flash_attn = True
-    _supports_sdpa = True
-    _is_stateful = True
-
-    @torch.no_grad()
-    def _init_weights(self, module):
-        super()._init_weights(module)
-        if isinstance(module, FalconH1Mixer):
-            init.ones_(module.dt_bias)
-            init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
-            init.ones_(module.D)
-
-
 def compute_mup_vector(config):
     """
     Computes the MuP vector based on model configuration.
@@ -1244,6 +1224,30 @@ def compute_mup_vector(config):
     return mup_vector
 
 
+@auto_docstring
+class FalconH1PreTrainedModel(PreTrainedModel):
+    config: FalconH1Config
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["FalconH1DecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _is_stateful = True
+
+    @torch.no_grad()
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, FalconH1Mixer):
+            init.ones_(module.dt_bias)
+            init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
+            init.ones_(module.D)
+        elif isinstance(module, FalconH1Model):
+            mup_vector = compute_mup_vector(module.config)
+            for layer in module.layers:
+                init.copy_(layer.mamba.mup_vector, mup_vector)
+
+
 @auto_docstring
 # Adapted from transformers.models.jamba.modeling_jamba.JambaModel
 class FalconH1Model(FalconH1PreTrainedModel):
@@ -1269,7 +1273,7 @@ class FalconH1Model(FalconH1PreTrainedModel):
         # Compute the MuP vector once and register it for all layers
         mup_vector = compute_mup_vector(config)
         for layer in self.layers:
-            layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False)
+            layer.mamba.register_buffer("mup_vector", mup_vector.clone(), persistent=False)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1591,6 +1595,7 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`
@@ -1628,7 +1633,7 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin):
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
@@ -928,6 +928,10 @@ class FalconH1PreTrainedModel(PreTrainedModel):
             init.ones_(module.dt_bias)
             init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
             init.ones_(module.D)
+        elif isinstance(module, FalconH1Model):
+            mup_vector = compute_mup_vector(module.config)
+            for layer in module.layers:
+                init.copy_(layer.mamba.mup_vector, mup_vector)
 
 
 def compute_mup_vector(config):
@@ -992,7 +996,7 @@ class FalconH1Model(FalconH1PreTrainedModel):
         # Compute the MuP vector once and register it for all layers
         mup_vector = compute_mup_vector(config)
         for layer in self.layers:
-            layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False)
+            layer.mamba.register_buffer("mup_vector", mup_vector.clone(), persistent=False)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1298,6 +1302,7 @@ class FalconH1ForCausalLM(LlamaForCausalLM):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`
@@ -1335,7 +1340,7 @@ class FalconH1ForCausalLM(LlamaForCausalLM):
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
@@ -31,7 +31,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
-from ...integrations
+from ...integrations import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
@@ -345,7 +345,7 @@ class FalconMambaMixer(nn.Module):
 
         # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
         # at the price of a small overhead.
-        if hasattr(self.config, "_pre_quantization_dtype"):
+        if hasattr(self.config, "quantization_config"):
             discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
         else:
             discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
@@ -613,6 +613,9 @@ class FalconMambaPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight, std=std)
+        if isinstance(module, FalconMambaMixer):
+            init.ones_(module.b_c_rms)
+            init.ones_(module.dt_rms)
 
 
 @dataclass
@@ -811,6 +814,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
         cache_params: Optional[FalconMambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        is_first_iteration: Optional[bool] = False,
         **kwargs,
     ):
         # Overwritten -- uses `cache_params` as opposed to `past_key_values`
@@ -19,6 +19,7 @@ from typing import Optional
 import torch
 from torch import nn
 
+from ... import initialization as init
 from ...utils import auto_docstring, logging
 from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling
 from ..mamba.configuration_mamba import MambaConfig
@@ -357,7 +358,7 @@ class FalconMambaMixer(MambaMixer):
 
         # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
         # at the price of a small overhead.
-        if hasattr(self.config, "_pre_quantization_dtype"):
+        if hasattr(self.config, "quantization_config"):
             discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
         else:
             discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
@@ -529,7 +530,11 @@ class FalconMambaBlock(MambaBlock):
 
 @auto_docstring
 class FalconMambaPreTrainedModel(MambaPreTrainedModel):
-
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, FalconMambaMixer):
+            init.ones_(module.b_c_rms)
+            init.ones_(module.dt_rms)
 
 
 class FalconMambaOutput(MambaOutput):
@@ -430,6 +430,7 @@ class FastVlmForConditionalGeneration(FastVlmPreTrainedModel, GenerationMixin):
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -441,12 +442,15 @@ class FastVlmForConditionalGeneration(FastVlmPreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if
-            #
-            #
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration if available
+            # In subsquent iterations, they are already merged with text and cached
+            # NOTE: first iteration doesn't have to be prefill, it can be the first
+            # iteration with a question and cached system prompt (continue generate from cache)
             model_inputs["pixel_values"] = pixel_values
 
         return model_inputs
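The `is_first_iteration` flag threaded through these prepare-inputs overrides (FalconH1, FalconMamba, FastVLM and others in this release) decides whether step-specific inputs such as `inputs_embeds` or `pixel_values` are forwarded, or assumed to already be merged into the cache from an earlier step. A rough, free-standing sketch of that control flow; the function below is an illustration under assumed names and simplified arguments, not the library's actual `prepare_inputs_for_generation` signature:

from typing import Optional

import torch


def prepare_inputs_sketch(
    input_ids: torch.LongTensor,
    pixel_values: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    is_first_iteration: bool = False,
    use_cache: bool = True,
) -> dict:
    """Illustrative only: mirrors the pattern in the hunks above."""
    if inputs_embeds is not None and is_first_iteration:
        # Embeddings are only meaningful on the first step; later steps continue
        # from cached key/value states and freshly generated input_ids.
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    if is_first_iteration or not use_cache:
        # Vision features are merged into the sequence once; subsequent steps reuse the cache.
        model_inputs["pixel_values"] = pixel_values

    return model_inputs


inputs = prepare_inputs_sketch(
    torch.tensor([[1, 2, 3]]), pixel_values=torch.zeros(1, 3, 8, 8), is_first_iteration=True
)
print(sorted(inputs.keys()))  # ['input_ids', 'pixel_values']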
@@ -727,19 +727,20 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
         self.embed_dim = config.hidden_size
         self.input_scale = math.sqrt(self.embed_dim)
         self.dropout = nn.Dropout(p=module_config["positional_dropout_rate"])
-        self.pos_enc = None
         self.max_len = 5000
-        self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len))
+        self.register_buffer(
+            "pos_enc", self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len)), persistent=False
+        )
 
-    def extend_pos_enc(self, x):
+    def extend_pos_enc(self, x, pos_enc=None):
         """Reset the positional encodings."""
-        if self.pos_enc is not None:
+        if pos_enc is not None:
             # self.pos_enc contains both positive and negative parts
             # the length of self.pos_enc is 2 * input_len - 1
-            if self.pos_enc.size(1) >= x.size(1) * 2 - 1:
-                if self.pos_enc.dtype != x.dtype or self.pos_enc.device != x.device:
-                    self.pos_enc = self.pos_enc.to(dtype=x.dtype, device=x.device)
-                return
+            if pos_enc.size(1) >= x.size(1) * 2 - 1:
+                if pos_enc.dtype != x.dtype or pos_enc.device != x.device:
+                    pos_enc = pos_enc.to(dtype=x.dtype, device=x.device)
+                return pos_enc
         # Suppose `i` means to the position of query vector and `j` means the
         # position of key vector. We use position relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -760,7 +761,7 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
         pos_enc_positive = torch.flip(pos_enc_positive, [0]).unsqueeze(0)
         pos_enc_negative = pos_enc_negative[1:].unsqueeze(0)
         pos_enc = torch.cat([pos_enc_positive, pos_enc_negative], dim=1)
-        self.pos_enc = pos_enc.to(device=x.device, dtype=x.dtype)
+        return pos_enc.to(device=x.device, dtype=x.dtype)
 
     def forward(self, feature_representation):
         """
@@ -771,7 +772,7 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
         Returns:
             `torch.Tensor`: Encoded tensor (batch_size, time, `*`).
         """
-        self.extend_pos_enc(feature_representation)
+        self.pos_enc = self.extend_pos_enc(feature_representation, self.pos_enc)
         hidden_states = feature_representation * self.input_scale
         center_idx = self.pos_enc.size(1) // 2
         pos_emb = self.pos_enc[:, center_idx - hidden_states.size(1) + 1 : center_idx + hidden_states.size(1)]
@@ -1010,6 +1011,10 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel):
         elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight)
             # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
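For reference, `nn.BatchNorm1d` keeps its running statistics as buffers rather than parameters, which is why they need an explicit reset alongside `weight`/`bias`. A quick plain-PyTorch check (not package code):

import torch.nn as nn

bn = nn.BatchNorm1d(4)
print(sorted(name for name, _ in bn.named_parameters()))  # ['bias', 'weight']
print(sorted(name for name, _ in bn.named_buffers()))     # ['num_batches_tracked', 'running_mean', 'running_var']
bn.running_mean.zero_()          # reset the buffers in place
bn.running_var.fill_(1.0)
bn.num_batches_tracked.zero_()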
@@ -1018,6 +1023,8 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel):
         elif isinstance(module, FastSpeech2ConformerAttention):
             init.xavier_uniform_(module.pos_bias_u)
             init.xavier_uniform_(module.pos_bias_v)
+        elif isinstance(module, FastSpeech2ConformerRelPositionalEncoding):
+            init.copy_(module.pos_enc, module.extend_pos_enc(torch.tensor(0.0).expand(1, module.max_len)))

     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, FastSpeech2ConformerEncoder):
@@ -1410,6 +1417,12 @@ class FastSpeech2ConformerHifiGan(PreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, FastSpeech2ConformerHifiGan):
+            init.zeros_(module.mean)
+            init.ones_(module.scale)
+
     def apply_weight_norm(self):
         weight_norm = nn.utils.weight_norm
         if hasattr(nn.utils.parametrizations, "weight_norm"):
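Several hunks in this diff add per-model `_init_weights` overrides like the one above. A simplified sketch of why that is sufficient (this is not the real `PreTrainedModel.post_init` flow, just the shape of it): the hook is applied to every submodule, and `isinstance` checks pick out the ones each override cares about.

import torch.nn as nn

class ToyPreTrained(nn.Module):
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.zeros_(module.bias)

    def post_init(self):
        self.apply(self._init_weights)  # visits self and every submodule

class ToyModel(ToyPreTrained):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.post_init()

model = ToyModel()
assert model.proj.bias.abs().sum() == 0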
@@ -660,9 +660,6 @@ class FlaubertPreTrainedModel(PreTrainedModel):
     config: FlaubertConfig
     base_model_prefix = "transformer"

-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
     @property
     def dummy_inputs(self):
         inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
@@ -690,15 +687,17 @@ class FlaubertPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
-        if isinstance(module, FlaubertModel)
-
-
-
-
-
-
-
+        if isinstance(module, FlaubertModel):
+            if self.config.sinusoidal_embeddings:
+                init.copy_(
+                    module.position_embeddings.weight,
+                    create_sinusoidal_embeddings(
+                        self.config.max_position_embeddings,
+                        self.config.emb_dim,
+                        out=torch.empty_like(module.position_embeddings.weight),
+                    ),
+                )
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 @auto_docstring
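The branch above copies a fixed sinusoidal table into `position_embeddings.weight` when `sinusoidal_embeddings` is set. A generic sketch of such a table; the exact layout used by the model is defined by the library's `create_sinusoidal_embeddings` helper, not by this snippet:

import torch

def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    # Standard fixed sinusoidal position table: sine on even columns, cosine on odd columns.
    assert dim % 2 == 0
    pos = torch.arange(n_pos, dtype=torch.float32).unsqueeze(1)
    freq = 10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    table = torch.zeros(n_pos, dim)
    table[:, 0::2] = torch.sin(pos / freq)
    table[:, 1::2] = torch.cos(pos / freq)
    return table

weights = sinusoidal_table(512, 64)  # (max_position_embeddings, emb_dim)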
@@ -760,15 +759,15 @@ class FlaubertModel(FlaubertPreTrainedModel):
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))

-        # Initialize weights and apply final processing
-        self.post_init()
-
         self.layerdrop = getattr(config, "layerdrop", 0.0)
         self.pre_norm = getattr(config, "pre_norm", False)
         self.register_buffer(
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )

+        # Initialize weights and apply final processing
+        self.post_init()
+
     # Copied from transformers.models.xlm.modeling_xlm.XLMModel.get_input_embeddings
     def get_input_embeddings(self):
         return self.embeddings
@@ -306,7 +306,6 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

         return processed_images

@@ -397,7 +396,6 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
             mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
         )
         masks = [mask_generator() for _ in range(len(images))]
-        masks = torch.stack(masks, dim=0) if return_tensors else masks
         data["bool_masked_pos"] = masks

         return BatchFeature(data=data, tensor_type=return_tensors)
@@ -677,6 +677,9 @@ class FlavaPreTrainedModel(PreTrainedModel):
             init.zeros_(module.position_embeddings)
             if module.mask_token is not None:
                 init.zeros_(module.mask_token)
+        elif isinstance(module, FlavaTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
         elif isinstance(module, FlavaMultimodalModel):
             if module.use_cls_token:
                 init.zeros_(module.cls_token)
@@ -1107,7 +1110,7 @@ class FlavaModel(FlavaPreTrainedModel):
         output_hidden_states: bool = True,
         return_dict: Optional[bool] = None,
         **kwargs,
-    ) -> Union[tuple,
+    ) -> Union[tuple, FlavaModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, image_num_patches + text_seq_len)`):
             Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
@@ -30,14 +30,14 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
+from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_flex_olmo import FlexOlmoConfig

@@ -80,7 +80,7 @@ class FlexOlmoRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -293,6 +293,7 @@ class FlexOlmoAttention(nn.Module):
         return attn_output, attn_weights


+@use_experts_implementation
 class FlexOlmoExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

@@ -421,7 +422,9 @@ class FlexOlmoPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
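The `_can_compile_fullgraph` flag is now computed once at import time from a capability probe. An illustrative sketch of that pattern only; `_grouped_mm_available` below is a stand-in probe, not the real `transformers.utils.is_grouped_mm_available` implementation:

import torch

def _grouped_mm_available() -> bool:
    # Hypothetical probe: treat the presence of a grouped-mm primitive as the capability signal.
    return hasattr(torch, "_grouped_mm")

class ToyMoePreTrainedModel:
    # Fullgraph compile is only claimed when the grouped-mm kernel path exists.
    _can_compile_fullgraph = _grouped_mm_available()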
@@ -26,6 +26,7 @@ from typing import Any, Optional, Union
 import torch.nn as nn
 import torch.nn.functional as F

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
@@ -629,6 +630,18 @@ class Florence2PreTrainedModel(PreTrainedModel):
     _supports_attention_backend = False
     config_class = Florence2Config

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, Florence2VisionPositionalEmbeddingCosine1D):
+            pos_idx_to_embed = torch.empty((module.max_seq_len, module.embed_dim))
+            sine, cosine = module.get_sinusoid_embeddings(
+                max_positions=module.max_seq_len,
+                embed_dim=module.embed_dim,
+            )
+            pos_idx_to_embed[:, 0::2] = sine
+            pos_idx_to_embed[:, 1::2] = cosine
+            init.copy_(module.pos_idx_to_embed, pos_idx_to_embed)
+

 @auto_docstring(
     custom_intro="""
@@ -937,6 +950,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -948,12 +962,15 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        if
-            #
-            #
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration if available
+            # In subsquent iterations, they are already merged with text and cached
+            # NOTE: first iteration doesn't have to be prefill, it can be the first
+            # iteration with a question and cached system prompt (continue generate from cache)
             model_inputs["pixel_values"] = pixel_values

         return model_inputs
@@ -22,6 +22,7 @@ import numpy as np
 import torch.nn as nn
 import torch.nn.functional as F

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...configuration_utils import PreTrainedConfig
@@ -1500,6 +1501,18 @@ class Florence2PreTrainedModel(LlavaPreTrainedModel):

     _supports_attention_backend = False

+    def _init_weights(self, module):
+        PreTrainedModel._init_weights(self, module)
+        if isinstance(module, Florence2VisionPositionalEmbeddingCosine1D):
+            pos_idx_to_embed = torch.empty((module.max_seq_len, module.embed_dim))
+            sine, cosine = module.get_sinusoid_embeddings(
+                max_positions=module.max_seq_len,
+                embed_dim=module.embed_dim,
+            )
+            pos_idx_to_embed[:, 0::2] = sine
+            pos_idx_to_embed[:, 1::2] = cosine
+            init.copy_(module.pos_idx_to_embed, pos_idx_to_embed)
+

 @auto_docstring(
     custom_intro="""
@@ -23,6 +23,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...utils import auto_docstring, is_scipy_available


@@ -374,6 +375,12 @@ class FNetPreTrainedModel(PreTrainedModel):
     base_model_prefix = "fnet"
     supports_gradient_checkpointing = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, FNetEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
+

 @dataclass
 @auto_docstring(
@@ -94,7 +94,7 @@ class FuyuBatchFeature(BatchFeature):
     The outputs dictionary from the processors contains a mix of tensors and lists of tensors.
     """

-    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
+    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None, **kwargs):
         """
         Convert the inner content to tensors.

@@ -359,6 +359,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
         image_patches=None,
         image_patches_indices=None,
         cache_position=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -371,10 +372,11 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
             image_patches=image_patches,
             image_patches_indices=image_patches_indices,
             cache_position=cache_position,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        if
+        if not is_first_iteration and kwargs.get("use_cache", True):
             # set image_patches and image_patches_indices to `None` for decoding stage
             model_inputs["image_patches_indices"] = None
             model_inputs["image_patches"] = None