transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -29,7 +29,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -37,7 +42,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dots1 import Dots1Config

@@ -80,7 +85,7 @@ class Dots1RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -308,6 +313,7 @@ class Dots1TopkRouter(nn.Module):
         return router_logits


+@use_experts_implementation
 class Dots1NaiveMoe(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

@@ -315,7 +321,7 @@ class Dots1NaiveMoe(nn.Module):
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_dim = config.hidden_size
-        self.intermediate_dim = config.
+        self.intermediate_dim = config.moe_intermediate_size
         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
         self.act_fn = ACT2FN[config.hidden_act]
@@ -463,7 +469,9 @@ class Dots1PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": Dots1DecoderLayer,
@@ -476,6 +484,7 @@ class Dots1PreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, Dots1TopkRouter):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            init.zeros_(module.e_score_correction_bias)
         elif isinstance(module, Dots1NaiveMoe):
             init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
@@ -102,7 +102,7 @@ class DPTConfig(PreTrainedConfig):
             Used only for the `hybrid` embedding type. The shape of the feature maps of the backbone.
         neck_ignore_stages (`list[int]`, *optional*, defaults to `[0, 1]`):
             Used only for the `hybrid` embedding type. The stages of the readout layers to ignore.
-        backbone_config (`Union[dict
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `BitConfig()`):
             The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to
             leverage the [`AutoBackbone`] API.
         backbone (`str`, *optional*):
@@ -225,8 +225,7 @@ class DPTImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-
-        return BatchFeature(data={"pixel_values": processed_images})
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

     def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
         """
@@ -228,8 +228,7 @@ class DPTImageProcessorFast(BeitImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-
-        return BatchFeature(data={"pixel_values": processed_images})
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

     def post_process_depth_estimation(
         self,
@@ -33,7 +33,7 @@ class EdgeTamVisionConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.

     Args:
-        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `timm/repvit_m1.dist_in1k`):
             Configuration for the vision backbone. This is used to instantiate the backbone using
             `AutoModel.from_config`.
         backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`):
@@ -30,7 +30,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor

-from transformers.utils.generic import OutputRecorder
+from transformers.utils.generic import OutputRecorder

 from ... import initialization as init
 from ...activations import ACT2FN
@@ -39,6 +39,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import ModelOutput, auto_docstring
+from ...utils.generic import TransformersKwargs, check_model_inputs
 from ..auto import AutoModel
 from .configuration_edgetam import (
     EdgeTamConfig,
@@ -50,7 +51,7 @@ from .configuration_edgetam import (

 # fix this in modular
 if True:
-    from
+    from ..timm_wrapper.modeling_timm_wrapper import TimmWrapperModel


 class EdgeTamLayerNorm(nn.LayerNorm):
@@ -315,6 +316,8 @@ class EdgeTamPreTrainedModel(PreTrainedModel):
         if isinstance(module, EdgeTamModel):
             if module.no_memory_embedding is not None:
                 init.zeros_(module.no_memory_embedding)
+        elif hasattr(module, "positional_embedding"):
+            init.normal_(module.positional_embedding, std=module.scale)


 # copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
@@ -19,8 +19,17 @@ from typing import Optional, Union
 import torch
 import torch.utils.checkpoint

-from
-from
+from ... import initialization as init
+from ...configuration_utils import PreTrainedConfig
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+    auto_docstring,
+)
+from ...utils.generic import TransformersKwargs, check_model_inputs
+from ..auto import CONFIG_MAPPING, AutoConfig
+from ..sam2.configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig
+from ..sam2.modeling_sam2 import (
     Sam2Attention,
     Sam2FeedForward,
     Sam2LayerNorm,
@@ -30,21 +39,11 @@ from transformers.models.sam2.modeling_sam2 import (
     Sam2VisionEncoderOutput,
     Sam2VisionModel,
 )
-from transformers.utils.generic import TransformersKwargs, check_model_inputs
-
-from ... import initialization as init
-from ...configuration_utils import PreTrainedConfig
-from ...modeling_utils import PreTrainedModel
-from ...processing_utils import Unpack
-from ...utils import (
-    auto_docstring,
-)
-from ..auto import CONFIG_MAPPING, AutoConfig


 # fix this in modular
 if True:
-    from
+    from ..timm_wrapper.modeling_timm_wrapper import TimmWrapperModel


 class EdgeTamVisionConfig(PreTrainedConfig):
@@ -58,7 +57,7 @@ class EdgeTamVisionConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.

     Args:
-        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `timm/repvit_m1.dist_in1k`):
             Configuration for the vision backbone. This is used to instantiate the backbone using
             `AutoModel.from_config`.
         backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`):
@@ -181,6 +180,8 @@ class EdgeTamPreTrainedModel(Sam2PreTrainedModel):
         if isinstance(module, EdgeTamModel):
             if module.no_memory_embedding is not None:
                 init.zeros_(module.no_memory_embedding)
+        elif hasattr(module, "positional_embedding"):
+            init.normal_(module.positional_embedding, std=module.scale)


 @auto_docstring(
@@ -152,24 +152,17 @@ class EdgeTamVideoVisionRotaryEmbedding(nn.Module):
|
|
|
152
152
|
|
|
153
153
|
def __init__(self, config: EdgeTamVideoConfig, end_x: Optional[int] = None, end_y: Optional[int] = None):
|
|
154
154
|
super().__init__()
|
|
155
|
-
dim = config.memory_attention_hidden_size // (
|
|
155
|
+
self.dim = config.memory_attention_hidden_size // (
|
|
156
156
|
config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
|
|
157
157
|
)
|
|
158
158
|
# Ensure even dimension for proper axial splitting
|
|
159
|
-
if dim % 4 != 0:
|
|
159
|
+
if self.dim % 4 != 0:
|
|
160
160
|
raise ValueError("Dimension must be divisible by 4 for axial RoPE")
|
|
161
|
-
end_x, end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
|
|
162
|
-
|
|
161
|
+
self.end_x, self.end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
|
|
162
|
+
self.memory_attention_rope_theta = config.memory_attention_rope_theta
|
|
163
163
|
|
|
164
|
-
# Generate 2D position indices for axial rotary embedding
|
|
165
|
-
flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
|
|
166
|
-
x_positions = flattened_indices % end_x
|
|
167
|
-
y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
|
|
168
|
-
freqs_x = torch.outer(x_positions, freqs).float()
|
|
169
|
-
freqs_y = torch.outer(y_positions, freqs).float()
|
|
170
|
-
inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
|
|
171
|
-
inv_freq = inv_freq.repeat_interleave(2, dim=-1)
|
|
172
164
|
# directly register the cos and sin embeddings as we have a fixed feature shape
|
|
165
|
+
inv_freq = self.create_inv_freq()
|
|
173
166
|
self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
|
|
174
167
|
self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
|
|
175
168
|
|
|
@@ -178,6 +171,20 @@ class EdgeTamVideoVisionRotaryEmbedding(nn.Module):
|
|
|
178
171
|
# As the feature map size is fixed, we can just return the pre-computed embeddings.
|
|
179
172
|
return self.rope_embeddings_cos, self.rope_embeddings_sin
|
|
180
173
|
|
|
174
|
+
def create_inv_freq(self):
|
|
175
|
+
freqs = 1.0 / (
|
|
176
|
+
self.memory_attention_rope_theta ** (torch.arange(0, self.dim, 4)[: (self.dim // 4)].float() / self.dim)
|
|
177
|
+
)
|
|
178
|
+
# Generate 2D position indices for axial rotary embedding
|
|
179
|
+
flattened_indices = torch.arange(self.end_x * self.end_y, dtype=torch.long)
|
|
180
|
+
x_positions = flattened_indices % self.end_x
|
|
181
|
+
y_positions = torch.div(flattened_indices, self.end_x, rounding_mode="floor")
|
|
182
|
+
freqs_x = torch.outer(x_positions, freqs).float()
|
|
183
|
+
freqs_y = torch.outer(y_positions, freqs).float()
|
|
184
|
+
inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
|
|
185
|
+
inv_freq = inv_freq.repeat_interleave(2, dim=-1)
|
|
186
|
+
return inv_freq
|
|
187
|
+
|
|
181
188
|
|
|
182
189
|
def eager_attention_forward(
|
|
183
190
|
module: nn.Module,
|
|
@@ -769,6 +776,31 @@ class EdgeTamVideoFeedForward(nn.Module):
|
|
|
769
776
|
return hidden_states
|
|
770
777
|
|
|
771
778
|
|
|
779
|
+
class EdgeTamVideoPositionalEmbedding(nn.Module):
|
|
780
|
+
def __init__(self, config: EdgeTamVideoPromptEncoderConfig):
|
|
781
|
+
super().__init__()
|
|
782
|
+
self.scale = config.scale
|
|
783
|
+
positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
|
|
784
|
+
self.register_buffer("positional_embedding", positional_embedding)
|
|
785
|
+
|
|
786
|
+
def forward(self, input_coords, input_shape=None):
|
|
787
|
+
"""Positionally encode points that are normalized to [0,1]."""
|
|
788
|
+
coordinates = input_coords.clone()
|
|
789
|
+
|
|
790
|
+
if input_shape is not None:
|
|
791
|
+
coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
|
|
792
|
+
coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
|
|
793
|
+
coordinates.to(torch.float32)
|
|
794
|
+
|
|
795
|
+
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
|
796
|
+
coordinates = 2 * coordinates - 1
|
|
797
|
+
coordinates = coordinates.to(self.positional_embedding.dtype)
|
|
798
|
+
coordinates = coordinates @ self.positional_embedding
|
|
799
|
+
coordinates = 2 * np.pi * coordinates
|
|
800
|
+
# outputs d_1 x ... x d_n x channel shape
|
|
801
|
+
return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
|
|
802
|
+
|
|
803
|
+
|
|
772
804
|
@auto_docstring
|
|
773
805
|
class EdgeTamVideoPreTrainedModel(PreTrainedModel):
|
|
774
806
|
config_class = EdgeTamVideoConfig
|
|
@@ -794,6 +826,16 @@ class EdgeTamVideoPreTrainedModel(PreTrainedModel):
         if isinstance(module, EdgeTamVideoMemoryFuserCXBlock):
             if module.scale is not None:
                 init.zeros_(module.scale)
+        elif isinstance(module, EdgeTamVideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())
+        elif isinstance(module, EdgeTamVideoPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)
+        if isinstance(module, EdgeTamVideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())


 class EdgeTamVideoInferenceCache:

@@ -959,7 +1001,7 @@ class EdgeTamVideoInferenceSession:
         device_inputs = {}
         for key, value in inputs.items():
             if isinstance(value, torch.Tensor):
-                device_inputs[key] = value.to(self.inference_device, non_blocking=
+                device_inputs[key] = value.to(self.inference_device, non_blocking=False)
             else:
                 device_inputs[key] = value
         self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs
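On the `non_blocking=False` change: with the flag off, `Tensor.to` returns only after the host-to-device copy has finished, so downstream code never has to reason about CUDA stream synchronization. A quick illustration (the device choice is just for the example):

```python
# Synchronous device transfer, as in the updated line above.
import torch

value = torch.rand(4)
device = "cuda" if torch.cuda.is_available() else "cpu"
moved = value.to(device, non_blocking=False)  # returns after the copy completes
print(moved.device)
```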
@@ -1547,31 +1589,6 @@ class EdgeTamVideoSegmentationOutput(ModelOutput):
     frame_idx: Optional[int] = None


-class EdgeTamVideoPositionalEmbedding(nn.Module):
-    def __init__(self, config: EdgeTamVideoPromptEncoderConfig):
-        super().__init__()
-        self.scale = config.scale
-        positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
-        self.register_buffer("positional_embedding", positional_embedding)
-
-    def forward(self, input_coords, input_shape=None):
-        """Positionally encode points that are normalized to [0,1]."""
-        coordinates = input_coords.clone()
-
-        if input_shape is not None:
-            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
-            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
-            coordinates.to(torch.float32)
-
-        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
-        coordinates = 2 * coordinates - 1
-        coordinates = coordinates.to(self.positional_embedding.dtype)
-        coordinates = coordinates @ self.positional_embedding
-        coordinates = 2 * np.pi * coordinates
-        # outputs d_1 x ... x d_n x channel shape
-        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
-
-
 class EdgeTamVideoMaskEmbedding(nn.Module):
     def __init__(self, config: EdgeTamVideoPromptEncoderConfig):
         super().__init__()

@@ -1976,11 +1993,6 @@ class EdgeTamVideoModel(EdgeTamVideoPreTrainedModel):
     input_modalities = ("video", "text")
     _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamVideoTwoWayAttentionBlock, index=2)}
     _keys_to_ignore_on_load_unexpected = []
-    _tied_weights_keys = {
-        "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
-    }
-    # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
-    _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]

     def __init__(self, config: EdgeTamVideoConfig):
         super().__init__(config)

@@ -29,6 +29,7 @@ from transformers.models.sam2.modeling_sam2 import (
 )
 from transformers.utils.generic import OutputRecorder

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...modeling_flash_attention_utils import FlashAttentionKwargs

@@ -375,24 +376,17 @@ class EdgeTamVideoVisionEncoderOutput(Sam2VideoVisionEncoderOutput):
 class EdgeTamVideoVisionRotaryEmbedding(Sam2VideoVisionRotaryEmbedding):
     def __init__(self, config: EdgeTamVideoConfig, end_x: Optional[int] = None, end_y: Optional[int] = None):
         nn.Module.__init__()
-        dim = config.memory_attention_hidden_size // (
+        self.dim = config.memory_attention_hidden_size // (
             config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
         )
         # Ensure even dimension for proper axial splitting
-        if dim % 4 != 0:
+        if self.dim % 4 != 0:
             raise ValueError("Dimension must be divisible by 4 for axial RoPE")
-        end_x, end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
-
-
-        # Generate 2D position indices for axial rotary embedding
-        flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
-        x_positions = flattened_indices % end_x
-        y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
-        freqs_x = torch.outer(x_positions, freqs).float()
-        freqs_y = torch.outer(y_positions, freqs).float()
-        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
-        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+        self.end_x, self.end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
+        self.memory_attention_rope_theta = config.memory_attention_rope_theta
+
         # directly register the cos and sin embeddings as we have a fixed feature shape
+        inv_freq = self.create_inv_freq()
         self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
         self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)

@@ -662,7 +656,12 @@ class EdgeTamVideoFeedForward(Sam2VideoFeedForward):


 class EdgeTamVideoPreTrainedModel(Sam2VideoPreTrainedModel):
-
+    def _init_weights(self, module):
+        super()._init_weights()
+        if isinstance(module, EdgeTamVideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())


 class EdgeTamVideoInferenceSession(Sam2VideoInferenceSession):

@@ -1040,11 +1039,6 @@ class EdgeTamVideoSegmentationOutput(Sam2VideoSegmentationOutput):

 @auto_docstring
 class EdgeTamVideoModel(Sam2VideoModel):
-    _tied_weights_keys = {
-        "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
-    }
-    # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
-    _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
     _keys_to_ignore_on_load_unexpected = []
     _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamVideoTwoWayAttentionBlock, index=2)}


@@ -153,9 +153,8 @@ class EfficientLoFTRImageProcessorFast(BaseImageProcessorFast):
         stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]

         # Return in same format as slow processor
-        image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs

-        return BatchFeature(data={"pixel_values":
+        return BatchFeature(data={"pixel_values": stacked_pairs}, tensor_type=return_tensors)

     def post_process_keypoint_matching(
         self,
@@ -103,7 +103,7 @@ class EfficientLoFTRRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     # Ignore copy
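`original_inv_freq` is now registered as a non-persistent buffer instead of a plain attribute, so it follows `.to()` / `.cuda()` moves with the module while staying out of the checkpoint. A minimal check of that behavior on a toy module:

```python
# Non-persistent buffers move with the module but are not serialized.
import torch
from torch import nn


class ToyRope(nn.Module):
    def __init__(self, inv_freq: torch.Tensor):
        super().__init__()
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)


rope = ToyRope(torch.ones(4))
print(list(rope.state_dict().keys()))  # [] -> neither buffer lands in the checkpoint
```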
@@ -684,9 +684,22 @@ class EfficientLoFTRPreTrainedModel(PreTrainedModel):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
        elif isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
+        elif isinstance(module, EfficientLoFTRRotaryEmbedding):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)

     # Copied from transformers.models.superpoint.modeling_superpoint.SuperPointPreTrainedModel.extract_one_channel_pixel_values with SuperPoint->EfficientLoFTR
     def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
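The new `running_mean` branch resets BatchNorm's running statistics along with its affine parameters during weight initialization. Below is a toy analogue using plain `torch.nn.init`; the `init.*` helpers in the diff are transformers-internal, and the module here is just an example.

```python
# Toy analogue of the BatchNorm handling added above, using plain torch.nn.init.
import torch
from torch import nn


def init_batchnorm(module: nn.Module) -> None:
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
        if module.running_mean is not None:  # only present with track_running_stats=True
            nn.init.zeros_(module.running_mean)
            nn.init.ones_(module.running_var)
            module.num_batches_tracked.zero_()


bn = nn.BatchNorm2d(8)
bn.running_mean.fill_(3.0)  # pretend the buffer holds stale statistics
init_batchnorm(bn)
print(bn.running_mean.abs().max().item())  # 0.0
```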
@@ -66,7 +66,7 @@ class EfficientNetImageProcessor(BaseImageProcessor):
            `do_resize` in `preprocess`.
        size (`dict[str, int]` *optional*, defaults to `{"height": 346, "width": 346}`):
            Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
-        resample (`PILImageResampling` filter, *optional*, defaults to
+        resample (`PILImageResampling` filter, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
        do_center_crop (`bool`, *optional*, defaults to `False`):
            Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image

@@ -102,7 +102,7 @@ class EfficientNetImageProcessor(BaseImageProcessor):
        self,
        do_resize: bool = True,
        size: Optional[dict[str, int]] = None,
-        resample: PILImageResampling =
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = False,
        crop_size: Optional[dict[str, int]] = None,
        rescale_factor: Union[int, float] = 1 / 255,

@@ -133,12 +133,11 @@ class EfficientNetImageProcessor(BaseImageProcessor):
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.include_top = include_top

-    # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.NEAREST
    def resize(
        self,
        image: np.ndarray,
        size: dict[str, int],
-        resample: PILImageResampling = PILImageResampling.
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,

@@ -151,8 +150,8 @@ class EfficientNetImageProcessor(BaseImageProcessor):
                Image to resize.
            size (`dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
-            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.
-                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:

@@ -33,7 +33,7 @@ from .image_processing_efficientnet import EfficientNetImageProcessorKwargs

 @auto_docstring
 class EfficientNetImageProcessorFast(BaseImageProcessorFast):
-    resample = PILImageResampling.
+    resample = PILImageResampling.BICUBIC
     image_mean = IMAGENET_STANDARD_MEAN
     image_std = IMAGENET_STANDARD_STD
     size = {"height": 346, "width": 346}

@@ -178,7 +178,6 @@ class EfficientNetImageProcessorFast(BaseImageProcessorFast):
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

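For reference, the removed line in both fast image processors performed the batching explicitly; the updated code passes the list of per-image tensors to `BatchFeature` and relies on `tensor_type` to produce the same batch (note that `feature_extraction_utils.py` also changed in this release). The explicit version of that stacking, with arbitrary shapes, looks like this:

```python
# What the removed line did: stack per-image tensors into one batch tensor when a
# tensor return type is requested; the new code leaves this step to BatchFeature.
import torch

processed_images = [torch.rand(3, 224, 224) for _ in range(4)]  # arbitrary example shapes
return_tensors = "pt"

pixel_values = torch.stack(processed_images, dim=0) if return_tensors else processed_images
print(pixel_values.shape)  # torch.Size([4, 3, 224, 224])
```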
@@ -435,7 +435,7 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
     base_model_prefix = "efficientnet"
     main_input_name = "pixel_values"
     input_modalities = ("image",)
-    _no_split_modules = []
+    _no_split_modules = ["EfficientNetBlock"]

     @torch.no_grad()
     def _init_weights(self, module: nn.Module):

@@ -444,6 +444,10 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)


 @auto_docstring

@@ -22,6 +22,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...activations import ACT2FN, get_activation
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin

@@ -532,6 +533,12 @@ class ElectraPreTrainedModel(PreTrainedModel):
        "cross_attentions": ElectraCrossAttention,
    }

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, ElectraEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
+

 @dataclass
 @auto_docstring(
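The new `ElectraPreTrainedModel._init_weights` override regenerates the embeddings' `position_ids` and `token_type_ids` buffers rather than relying on serialized values. The same effect with plain in-place tensor ops on a toy module is shown below; the 512 length is an arbitrary example value, and `init.copy_` / `init.zeros_` in the diff are transformers-internal helpers.

```python
# Toy equivalent of the buffer re-initialization above, without transformers' init helpers.
import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    def __init__(self, max_position_embeddings: int = 512):
        super().__init__()
        self.register_buffer("position_ids", torch.zeros(1, max_position_embeddings, dtype=torch.long))
        self.register_buffer("token_type_ids", torch.zeros(1, max_position_embeddings, dtype=torch.long))


emb = ToyEmbeddings()
with torch.no_grad():
    emb.position_ids.copy_(torch.arange(emb.position_ids.shape[-1]).expand((1, -1)))
    emb.token_type_ids.zero_()

print(emb.position_ids[0, :5].tolist())  # [0, 1, 2, 3, 4]
```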
@@ -958,6 +958,10 @@ class Emu3VQVAE(PreTrainedModel):
        elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
            init.constant_(module.weight, 1.0)
            init.constant_(module.bias, 0.0)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight)
            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag

@@ -1128,7 +1132,7 @@ class Emu3RotaryEmbedding(nn.Module):
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(

@@ -1615,6 +1619,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
        position_ids=None,
        use_cache=True,
        pixel_values=None,
+        is_first_iteration=False,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

@@ -1628,10 +1633,11 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
            position_ids=position_ids,
            pixel_values=pixel_values,
            use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
            **kwargs,
        )

-        if
+        if not is_first_iteration and use_cache:
            model_inputs["pixel_values"] = None

        return model_inputs
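The `is_first_iteration` flag makes the pruning of `pixel_values` explicit: the image inputs are only needed on the prefill step, and later decoding steps that reuse the KV cache drop them. A toy version of that gating is sketched below; the function name and dict keys are illustrative, not the model's API.

```python
# Toy gating mirroring the updated condition above.
def prepare_inputs(pixel_values, is_first_iteration: bool, use_cache: bool = True) -> dict:
    model_inputs = {"pixel_values": pixel_values, "use_cache": use_cache}
    if not is_first_iteration and use_cache:
        model_inputs["pixel_values"] = None  # later steps rely on the cached context
    return model_inputs


print(prepare_inputs("pixels", is_first_iteration=True)["pixel_values"])   # pixels
print(prepare_inputs("pixels", is_first_iteration=False)["pixel_values"])  # None
```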
@@ -706,6 +706,10 @@ class Emu3VQVAE(PreTrainedModel):
        elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
            init.constant_(module.weight, 1.0)
            init.constant_(module.bias, 0.0)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight)
            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag

@@ -1167,6 +1171,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
        position_ids=None,
        use_cache=True,
        pixel_values=None,
+        is_first_iteration=False,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

@@ -1180,10 +1185,11 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
            position_ids=position_ids,
            pixel_values=pixel_values,
            use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
            **kwargs,
        )

-        if
+        if not is_first_iteration and use_cache:
            model_inputs["pixel_values"] = None

        return model_inputs