transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/resnet/modeling_resnet.py

@@ -262,9 +262,14 @@ class ResNetPreTrainedModel(PreTrainedModel):
                 fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
                 bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                 init.uniform_(module.bias, -bound, bound)
-
-
-            init.
+        # We need to check it like that as some Detr models replace the BatchNorm2d by their own
+        elif "BatchNorm" in module.__class__.__name__:
+            init.ones_(module.weight)
+            init.zeros_(module.bias)
+            init.zeros_(module.running_mean)
+            init.ones_(module.running_var)
+            if getattr(module, "num_batches_tracked", None) is not None:
+                init.zeros_(module.num_batches_tracked)
 
 
 @auto_docstring
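The new branch matches on the class name rather than on `nn.BatchNorm2d`, so the custom BatchNorm replacements some DETR-style models swap in get their statistics reset as well. Below is a minimal standalone sketch of the same pattern, assuming plain `torch.nn.init` instead of transformers' internal `init` helpers and a hypothetical `FrozenBatchNorm2d` stand-in:

```python
import torch
from torch import nn


class FrozenBatchNorm2d(nn.Module):
    """Toy stand-in for a custom BatchNorm replacement (hypothetical, for illustration only)."""

    def __init__(self, num_features):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer("running_mean", torch.empty(num_features))
        self.register_buffer("running_var", torch.empty(num_features))
        # unlike nn.BatchNorm2d, no num_batches_tracked buffer


def reset_batchnorm_like(module):
    # Duck-typed check mirroring the diff: match any class whose name contains "BatchNorm".
    if "BatchNorm" in module.__class__.__name__:
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
        nn.init.zeros_(module.running_mean)
        nn.init.ones_(module.running_var)
        # Guard the optional buffer, exactly as the new code does.
        if getattr(module, "num_batches_tracked", None) is not None:
            nn.init.zeros_(module.num_batches_tracked)


reset_batchnorm_like(FrozenBatchNorm2d(8))   # caught by the name check
reset_batchnorm_like(nn.BatchNorm2d(8))      # stock BatchNorm2d still works
```
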
transformers/models/roberta/modeling_roberta.py

@@ -501,6 +501,9 @@ class RobertaPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, RobertaLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, RobertaEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
 
 
 class RobertaEncoder(nn.Module):
transformers/models/roberta/modular_roberta.py

@@ -172,6 +172,9 @@ class RobertaPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, RobertaLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, RobertaEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
 
 
 class RobertaModel(BertModel):
transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py

@@ -561,6 +561,9 @@ class RobertaPreLayerNormPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, RobertaPreLayerNormLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, RobertaPreLayerNormEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
 
 
 @auto_docstring(
transformers/models/roc_bert/modeling_roc_bert.py

@@ -621,6 +621,9 @@ class RoCBertPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, RoCBertLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, RoCBertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
 
 
 @auto_docstring(
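The Roberta-family and RoCBert branches above re-create the embeddings' `position_ids` and `token_type_ids` buffers during weight initialization instead of relying on values set at construction time. A short sketch of the same buffer-filling idea using plain in-place tensor ops; the `ToyEmbeddings` module and helper below are illustrative, not the transformers classes:

```python
import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    """Minimal stand-in for RobertaEmbeddings-style position/token-type buffers."""

    def __init__(self, max_position_embeddings=512):
        super().__init__()
        self.register_buffer("position_ids", torch.empty(1, max_position_embeddings, dtype=torch.long), persistent=False)
        self.register_buffer("token_type_ids", torch.empty(1, max_position_embeddings, dtype=torch.long), persistent=False)


def init_embedding_buffers(module):
    # Same shape-driven fill as the diff, with plain in-place ops instead of transformers' `init` helpers.
    with torch.no_grad():
        module.position_ids.copy_(torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
        module.token_type_ids.zero_()


emb = ToyEmbeddings()
init_embedding_buffers(emb)
assert emb.position_ids[0, :3].tolist() == [0, 1, 2]
assert int(emb.token_type_ids.sum()) == 0
```
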
transformers/models/rt_detr/configuration_rt_detr.py

@@ -44,7 +44,7 @@ class RTDetrConfig(PreTrainedConfig):
             The epsilon used by the layer normalization layers.
         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the batch normalization layers.
-        backbone_config (`
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `RTDetrResNetConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
transformers/models/rt_detr/modeling_rt_detr.py

@@ -1059,6 +1059,10 @@ class RTDetrPreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
 
         elif isinstance(module, nn.LayerNorm):
             init.ones_(module.weight)
transformers/models/rt_detr/modeling_rt_detr_resnet.py

@@ -316,9 +316,14 @@ class RTDetrResNetPreTrainedModel(PreTrainedModel):
                 fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
                 bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                 init.uniform_(module.bias, -bound, bound)
-
-
-            init.
+        # We need to check it like that as some Detr models replace the BatchNorm2d by their own
+        elif "BatchNorm" in module.__class__.__name__:
+            init.ones_(module.weight)
+            init.zeros_(module.bias)
+            init.zeros_(module.running_mean)
+            init.ones_(module.running_var)
+            if getattr(module, "num_batches_tracked", None) is not None:
+                init.zeros_(module.num_batches_tracked)
 
 
 @auto_docstring(
transformers/models/rt_detr_v2/configuration_rt_detr_v2.py

@@ -18,7 +18,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from ...configuration_utils import PreTrainedConfig
 from ...utils import logging
 from ...utils.backbone_utils import verify_backbone_config_arguments

@@ -49,7 +48,7 @@ class RTDetrV2Config(PreTrainedConfig):
             The epsilon used by the layer normalization layers.
         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the batch normalization layers.
-        backbone_config (`
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this

@@ -357,8 +356,8 @@ class RTDetrV2Config(PreTrainedConfig):
         self.decoder_n_levels = decoder_n_levels
         self.decoder_offset_scale = decoder_offset_scale
         self.decoder_method = decoder_method
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
 
 
 __all__ = ["RTDetrV2Config"]
transformers/models/rt_detr_v2/modeling_rt_detr_v2.py

@@ -506,6 +506,10 @@ class RTDetrV2PreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
 
         elif isinstance(module, nn.LayerNorm):
             init.ones_(module.weight)

@@ -515,6 +519,9 @@ class RTDetrV2PreTrainedModel(PreTrainedModel):
             init.xavier_uniform_(module.weight_embedding.weight)
         if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
             init.xavier_uniform_(module.denoising_class_embed.weight)
+        if isinstance(module, RTDetrV2MultiscaleDeformableAttention):
+            n_points_scale = [1 / n for n in module.n_points_list for _ in range(n)]
+            init.copy_(module.n_points_scale, torch.tensor(n_points_scale, dtype=torch.float32))
 
 
 @dataclass
transformers/models/rt_detr_v2/modular_rt_detr_v2.py

@@ -19,6 +19,7 @@ import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 
+from ... import initialization as init
 from ...configuration_utils import PreTrainedConfig
 from ...utils import is_torchdynamo_compiling, logging
 from ...utils.backbone_utils import (

@@ -59,7 +60,7 @@ class RTDetrV2Config(PreTrainedConfig):
             The epsilon used by the layer normalization layers.
         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the batch normalization layers.
-        backbone_config (`
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this

@@ -367,8 +368,8 @@ class RTDetrV2Config(PreTrainedConfig):
         self.decoder_n_levels = decoder_n_levels
         self.decoder_offset_scale = decoder_offset_scale
         self.decoder_method = decoder_method
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
 
 
 def multi_scale_deformable_attention_v2(

@@ -564,7 +565,11 @@ class RTDetrV2DecoderLayer(RTDetrDecoderLayer):
 
 
 class RTDetrV2PreTrainedModel(RTDetrPreTrainedModel):
-
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, RTDetrV2MultiscaleDeformableAttention):
+            n_points_scale = [1 / n for n in module.n_points_list for _ in range(n)]
+            init.copy_(module.n_points_scale, torch.tensor(n_points_scale, dtype=torch.float32))
 
 
 class RTDetrV2Decoder(RTDetrDecoder):
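The `_init_weights` override added for `RTDetrV2PreTrainedModel` fills the deformable attention's `n_points_scale` buffer with one `1/n` entry per sampling point of each feature level. A worked example of that comprehension, assuming a hypothetical `n_points_list` of `[4, 4, 2]`:

```python
import torch

# Hypothetical per-level sampling point counts; the real values come from the RT-DETR v2 config.
n_points_list = [4, 4, 2]

# One scale entry per sampling point: 1/n repeated n times for a level with n points (as in the diff).
n_points_scale = [1 / n for n in n_points_list for _ in range(n)]
print(n_points_scale)  # [0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5]

# The buffer filled by init.copy_ is just this list as a float32 tensor of length sum(n_points_list).
print(torch.tensor(n_points_scale, dtype=torch.float32).shape)  # torch.Size([10])
```
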
transformers/models/rwkv/modeling_rwkv.py

@@ -49,7 +49,7 @@ def load_wkv_cuda_kernel(context_length):
     if not is_kernels_available():
         raise ImportError("kernels is not installed, please install it with `pip install kernels`")
 
-    from
+    from ...integrations.hub_kernels import get_kernel
 
     rwkv_cuda_kernel = get_kernel("kernels-community/rwkv")
     rwkv_cuda_kernel.max_seq_length = context_length
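The RWKV change routes kernel loading through `transformers.integrations.hub_kernels.get_kernel` instead of importing it directly. A rough sketch of the guarded lazy-loading pattern the function follows, assuming the standalone `kernels` package and its `get_kernel` entry point rather than the internal wrapper used in the diff:

```python
import importlib.util


def is_kernels_available() -> bool:
    # Availability check mirroring the guard in the diff (generic importlib version).
    return importlib.util.find_spec("kernels") is not None


def load_wkv_cuda_kernel(context_length: int):
    if not is_kernels_available():
        raise ImportError("kernels is not installed, please install it with `pip install kernels`")
    # Import lazily so the hub download only happens when the kernel is actually needed.
    from kernels import get_kernel

    rwkv_cuda_kernel = get_kernel("kernels-community/rwkv")
    rwkv_cuda_kernel.max_seq_length = context_length
    return rwkv_cuda_kernel
```
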
transformers/models/sam/configuration_sam.py

@@ -249,6 +249,7 @@ class SamVisionConfig(PreTrainedConfig):
         self.global_attn_indexes = global_attn_indexes
         self.num_pos_feats = num_pos_feats
         self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim
+        self.scale = self.hidden_size // 2
 
 
 class SamConfig(PreTrainedConfig):
transformers/models/sam/image_processing_sam_fast.py

@@ -267,7 +267,6 @@ class SamImageProcessorFast(BaseImageProcessorFast):
         if do_pad:
             processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)
 
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
         return BatchFeature(
             data={"pixel_values": processed_images, "reshaped_input_sizes": reshaped_input_sizes},
             tensor_type=return_tensors,
transformers/models/sam/modeling_sam.py

@@ -548,7 +548,7 @@ class SamMaskDecoder(nn.Module):
 class SamPositionalEmbedding(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.scale = config.
+        self.scale = config.scale
         self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
 
     def forward(self, input_coords, input_shape=None):

@@ -1014,6 +1014,8 @@ class SamPreTrainedModel(PreTrainedModel):
         elif isinstance(module, SamVisionEncoder):
             if self.config.use_abs_pos:
                 init.zeros_(module.pos_embed)
+        elif isinstance(module, SamPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)
 
 
 class SamVisionEncoder(SamPreTrainedModel):

@@ -1048,6 +1050,7 @@ class SamVisionEncoder(SamPreTrainedModel):
         self.neck = SamVisionNeck(config)
 
         self.gradient_checkpointing = False
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.patch_embed
transformers/models/sam2/configuration_sam2.py

@@ -152,7 +152,7 @@ class Sam2VisionConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.
 
     Args:
-        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `Sam2HieraDetConfig()`):
             Configuration for the vision backbone. This is used to instantiate the backbone using
             `AutoModel.from_config`.
         backbone_channel_list (`List[int]`, *optional*, defaults to `[768, 384, 192, 96]`):
transformers/models/sam2/modeling_sam2.py

@@ -565,7 +565,9 @@ class Sam2PreTrainedModel(PreTrainedModel):
                 init.zeros_(module.pos_embed)
             if module.pos_embed_window is not None:
                 init.zeros_(module.pos_embed_window)
-
+        elif isinstance(module, Sam2PositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)
+        elif isinstance(module, Sam2Model):
             if module.no_memory_embedding is not None:
                 init.zeros_(module.no_memory_embedding)
 

@@ -600,6 +602,8 @@ class Sam2HieraDetModel(Sam2PreTrainedModel):
                 self.blocks.append(block)
                 total_block_idx += 1
 
+        self.post_init()
+
     def get_input_embeddings(self):
         return self.patch_embed
 
transformers/models/sam2/modular_sam2.py

@@ -681,7 +681,9 @@ class Sam2PreTrainedModel(PreTrainedModel):
                 init.zeros_(module.pos_embed)
             if module.pos_embed_window is not None:
                 init.zeros_(module.pos_embed_window)
-
+        elif isinstance(module, Sam2PositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)
+        elif isinstance(module, Sam2Model):
             if module.no_memory_embedding is not None:
                 init.zeros_(module.no_memory_embedding)
 

@@ -716,6 +718,8 @@ class Sam2HieraDetModel(Sam2PreTrainedModel):
                 self.blocks.append(block)
                 total_block_idx += 1
 
+        self.post_init()
+
     def get_input_embeddings(self):
         return self.patch_embed
 
transformers/models/sam2_video/modeling_sam2_video.py

@@ -209,7 +209,7 @@ class Sam2VideoInferenceSession:
         device_inputs = {}
         for key, value in inputs.items():
             if isinstance(value, torch.Tensor):
-                device_inputs[key] = value.to(self.inference_device, non_blocking=
+                device_inputs[key] = value.to(self.inference_device, non_blocking=False)
             else:
                 device_inputs[key] = value
         self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs
@@ -688,6 +688,12 @@ class Sam2VideoPreTrainedModel(PreTrainedModel):
         if isinstance(module, Sam2VideoMemoryFuserCXBlock):
             if module.scale is not None:
                 init.zeros_(module.scale)
+        elif isinstance(module, Sam2VideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())
+        elif isinstance(module, Sam2VideoPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)
 
 
 class Sam2VideoVisionRotaryEmbedding(nn.Module):
@@ -698,24 +704,17 @@ class Sam2VideoVisionRotaryEmbedding(nn.Module):
 
     def __init__(self, config: Sam2VideoConfig):
         super().__init__()
-        dim = config.memory_attention_hidden_size // (
+        self.dim = config.memory_attention_hidden_size // (
             config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
         )
         # Ensure even dimension for proper axial splitting
-        if dim % 4 != 0:
+        if self.dim % 4 != 0:
             raise ValueError("Dimension must be divisible by 4 for axial RoPE")
-        end_x, end_y = config.memory_attention_rope_feat_sizes
-
+        self.end_x, self.end_y = config.memory_attention_rope_feat_sizes
+        self.memory_attention_rope_theta = config.memory_attention_rope_theta
 
-        # Generate 2D position indices for axial rotary embedding
-        flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
-        x_positions = flattened_indices % end_x
-        y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
-        freqs_x = torch.outer(x_positions, freqs).float()
-        freqs_y = torch.outer(y_positions, freqs).float()
-        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
-        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
         # directly register the cos and sin embeddings as we have a fixed feature shape
+        inv_freq = self.create_inv_freq()
         self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
         self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
 
@@ -724,6 +723,20 @@ class Sam2VideoVisionRotaryEmbedding(nn.Module):
         # As the feature map size is fixed, we can just return the pre-computed embeddings.
         return self.rope_embeddings_cos, self.rope_embeddings_sin

+    def create_inv_freq(self):
+        freqs = 1.0 / (
+            self.memory_attention_rope_theta ** (torch.arange(0, self.dim, 4)[: (self.dim // 4)].float() / self.dim)
+        )
+        # Generate 2D position indices for axial rotary embedding
+        flattened_indices = torch.arange(self.end_x * self.end_y, dtype=torch.long)
+        x_positions = flattened_indices % self.end_x
+        y_positions = torch.div(flattened_indices, self.end_x, rounding_mode="floor")
+        freqs_x = torch.outer(x_positions, freqs).float()
+        freqs_y = torch.outer(y_positions, freqs).float()
+        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+        return inv_freq
+

 def rotate_pairwise(x):
     """
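The `create_inv_freq` method factored out above builds the fixed axial 2D RoPE angle table that was previously computed inline in `__init__`; because the resulting cos/sin tables are non-persistent buffers, `_init_weights` can now rebuild them via `init.copy_`, as the earlier hunk shows. A standalone sketch of the same table construction; `dim`, `end_x`, `end_y`, and `theta` below are example values, not taken from any released config:

```python
import torch


def build_axial_rope_tables(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
    # Per-pair frequencies; the channel dimension is split between the x and y axes
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    # 2D positions of a flattened end_x * end_y feature grid
    flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
    x_positions = flattened_indices % end_x
    y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
    freqs_x = torch.outer(x_positions, freqs).float()
    freqs_y = torch.outer(y_positions, freqs).float()
    angles = torch.cat([freqs_x, freqs_y], dim=-1).repeat_interleave(2, dim=-1)
    return angles.cos(), angles.sin()


cos, sin = build_axial_rope_tables(dim=64, end_x=7, end_y=7)
print(cos.shape)  # torch.Size([49, 64])
```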
@@ -1101,6 +1114,31 @@ class Sam2VideoMemoryEncoder(nn.Module):
         return vision_features, vision_pos_enc


+class Sam2VideoPositionalEmbedding(nn.Module):
+    def __init__(self, config: Sam2VideoPromptEncoderConfig):
+        super().__init__()
+        self.scale = config.scale
+        positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
+        self.register_buffer("positional_embedding", positional_embedding)
+
+    def forward(self, input_coords, input_shape=None):
+        """Positionally encode points that are normalized to [0,1]."""
+        coordinates = input_coords.clone()
+
+        if input_shape is not None:
+            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
+            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
+            coordinates.to(torch.float32)
+
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coordinates = 2 * coordinates - 1
+        coordinates = coordinates.to(self.positional_embedding.dtype)
+        coordinates = coordinates @ self.positional_embedding
+        coordinates = 2 * np.pi * coordinates
+        # outputs d_1 x ... x d_n x channel shape
+        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
+
+
 @dataclass
 @auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
 class Sam2VideoVisionEncoderOutput(ModelOutput):
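The relocated `Sam2VideoPositionalEmbedding` encodes normalized point coordinates with random Fourier features: coordinates in [0, 1] are mapped to [-1, 1], projected through a fixed Gaussian matrix, and expanded into sin/cos features. A minimal standalone sketch of that math; the hidden size, scale, and point shapes below are illustrative, not model defaults:

```python
import numpy as np
import torch

hidden_size, scale = 256, 1.0
# Fixed Gaussian projection, analogous to the registered `positional_embedding` buffer
positional_embedding = scale * torch.randn((2, hidden_size // 2))

points = torch.rand(1, 1, 3, 2)         # batch x objects x points x (x, y), already in [0, 1]
coords = 2 * points - 1                  # map to [-1, 1]
coords = coords @ positional_embedding   # (..., hidden_size // 2)
coords = 2 * np.pi * coords
features = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
print(features.shape)                    # torch.Size([1, 1, 3, 256])
```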
@@ -1130,31 +1168,6 @@ class Sam2VideoVisionEncoderOutput(ModelOutput):
     attentions: Optional[tuple[torch.FloatTensor, ...]] = None


-class Sam2VideoPositionalEmbedding(nn.Module):
-    def __init__(self, config: Sam2VideoPromptEncoderConfig):
-        super().__init__()
-        self.scale = config.scale
-        positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
-        self.register_buffer("positional_embedding", positional_embedding)
-
-    def forward(self, input_coords, input_shape=None):
-        """Positionally encode points that are normalized to [0,1]."""
-        coordinates = input_coords.clone()
-
-        if input_shape is not None:
-            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
-            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
-            coordinates.to(torch.float32)
-
-        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
-        coordinates = 2 * coordinates - 1
-        coordinates = coordinates.to(self.positional_embedding.dtype)
-        coordinates = coordinates @ self.positional_embedding
-        coordinates = 2 * np.pi * coordinates
-        # outputs d_1 x ... x d_n x channel shape
-        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
-
-
 class Sam2VideoMaskEmbedding(nn.Module):
     def __init__(self, config: Sam2VideoPromptEncoderConfig):
         super().__init__()
@@ -1559,11 +1572,6 @@ class Sam2VideoModel(Sam2VideoPreTrainedModel):
     input_modalities = ("video", "text")
     _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2VideoTwoWayAttentionBlock, index=2)}
     _keys_to_ignore_on_load_unexpected = []
-    _tied_weights_keys = {
-        "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
-    }
-    # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
-    _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]

     def __init__(self, config: Sam2VideoConfig):
         super().__init__(config)
@@ -51,6 +51,7 @@ from ..sam2.modeling_sam2 import (
     Sam2ImageSegmentationOutput,
     Sam2LayerNorm,
     Sam2Model,
+    Sam2PositionalEmbedding,
     Sam2SinePositionEmbedding,
     Sam2TwoWayAttentionBlock,
     eager_attention_forward,
@@ -477,7 +478,7 @@ class Sam2VideoInferenceSession:
         device_inputs = {}
         for key, value in inputs.items():
             if isinstance(value, torch.Tensor):
-                device_inputs[key] = value.to(self.inference_device, non_blocking=
+                device_inputs[key] = value.to(self.inference_device, non_blocking=False)
             else:
                 device_inputs[key] = value
         self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs

@@ -1013,6 +1014,12 @@ class Sam2VideoPreTrainedModel(PreTrainedModel):
         if isinstance(module, Sam2VideoMemoryFuserCXBlock):
             if module.scale is not None:
                 init.zeros_(module.scale)
+        elif isinstance(module, Sam2VideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())
+        elif isinstance(module, Sam2VideoPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)


 class Sam2VideoVisionRotaryEmbedding(nn.Module):
@@ -1023,24 +1030,17 @@ class Sam2VideoVisionRotaryEmbedding(nn.Module):

     def __init__(self, config: Sam2VideoConfig):
         super().__init__()
-        dim = config.memory_attention_hidden_size // (
+        self.dim = config.memory_attention_hidden_size // (
             config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
         )
         # Ensure even dimension for proper axial splitting
-        if dim % 4 != 0:
+        if self.dim % 4 != 0:
             raise ValueError("Dimension must be divisible by 4 for axial RoPE")
-        end_x, end_y = config.memory_attention_rope_feat_sizes
-
+        self.end_x, self.end_y = config.memory_attention_rope_feat_sizes
+        self.memory_attention_rope_theta = config.memory_attention_rope_theta

-        # Generate 2D position indices for axial rotary embedding
-        flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
-        x_positions = flattened_indices % end_x
-        y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
-        freqs_x = torch.outer(x_positions, freqs).float()
-        freqs_y = torch.outer(y_positions, freqs).float()
-        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
-        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
         # directly register the cos and sin embeddings as we have a fixed feature shape
+        inv_freq = self.create_inv_freq()
         self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
         self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)

@@ -1049,6 +1049,20 @@ class Sam2VideoVisionRotaryEmbedding(nn.Module):
         # As the feature map size is fixed, we can just return the pre-computed embeddings.
         return self.rope_embeddings_cos, self.rope_embeddings_sin

+    def create_inv_freq(self):
+        freqs = 1.0 / (
+            self.memory_attention_rope_theta ** (torch.arange(0, self.dim, 4)[: (self.dim // 4)].float() / self.dim)
+        )
+        # Generate 2D position indices for axial rotary embedding
+        flattened_indices = torch.arange(self.end_x * self.end_y, dtype=torch.long)
+        x_positions = flattened_indices % self.end_x
+        y_positions = torch.div(flattened_indices, self.end_x, rounding_mode="floor")
+        freqs_x = torch.outer(x_positions, freqs).float()
+        freqs_y = torch.outer(y_positions, freqs).float()
+        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+        return inv_freq
+

 def rotate_pairwise(x):
     """
@@ -1426,6 +1440,10 @@ class Sam2VideoMemoryEncoder(nn.Module):
         return vision_features, vision_pos_enc


+class Sam2VideoPositionalEmbedding(Sam2PositionalEmbedding):
+    pass
+
+
 # a large negative value as a placeholder score for missing objects
 NO_OBJ_SCORE = -1024.0

@@ -1446,11 +1464,6 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
 @auto_docstring
 class Sam2VideoModel(Sam2Model):
     input_modalities = ("video", "text")
-    _tied_weights_keys = {
-        "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
-    }
-    # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
-    _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
     _keys_to_ignore_on_load_unexpected = []
     _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2VideoTwoWayAttentionBlock, index=2)}

@@ -122,7 +122,7 @@ class Sam3VisionConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.

     Args:
-        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `Sam3ViTConfig()`):
             Configuration for the vision backbone. This is used to instantiate the backbone using
             `AutoModel.from_config`.
         fpn_hidden_size (`int`, *optional*, defaults to 256):
@@ -179,6 +179,16 @@ class Sam3VisionConfig(PreTrainedConfig):
         self.initializer_range = initializer_range
         super().__init__(**kwargs)

+    @property
+    def image_size(self):
+        """Image size for the vision encoder."""
+        return self.backbone_config.image_size
+
+    @image_size.setter
+    def image_size(self, value):
+        """Set the image size and propagate to backbone."""
+        self.backbone_config.image_size = value
+

 class Sam3GeometryEncoderConfig(PreTrainedConfig):
     r"""
@@ -506,6 +516,16 @@ class Sam3Config(PreTrainedConfig):
         self.initializer_range = initializer_range
         super().__init__(**kwargs)

+    @property
+    def image_size(self):
+        """Image size for the SAM3 model."""
+        return self.vision_config.image_size
+
+    @image_size.setter
+    def image_size(self, value):
+        """Set the image size and propagate to vision config."""
+        self.vision_config.image_size = value
+

 __all__ = [
     "Sam3Config",
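Both `Sam3VisionConfig` and `Sam3Config` gain an `image_size` property that reads from, and writes through to, the nested config underneath it. A self-contained sketch of that delegation pattern; the Toy* classes are stand-ins for the real configs, not their actual constructors:

```python
class ToyBackboneConfig:
    def __init__(self, image_size=1024):
        self.image_size = image_size


class ToyVisionConfig:
    """Stand-in for Sam3VisionConfig: image_size lives on the backbone config."""

    def __init__(self):
        self.backbone_config = ToyBackboneConfig()

    @property
    def image_size(self):
        return self.backbone_config.image_size

    @image_size.setter
    def image_size(self, value):
        self.backbone_config.image_size = value


class ToyConfig:
    """Stand-in for Sam3Config: image_size delegates to the vision config."""

    def __init__(self):
        self.vision_config = ToyVisionConfig()

    @property
    def image_size(self):
        return self.vision_config.image_size

    @image_size.setter
    def image_size(self, value):
        self.vision_config.image_size = value


config = ToyConfig()
config.image_size = 512  # propagates through both levels
assert config.vision_config.backbone_config.image_size == 512
```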
@@ -417,6 +417,10 @@ class Sam3ViTRotaryEmbedding(nn.Module):
         # Ensure even dimension for proper axial splitting
         if dim % 4 != 0:
             raise ValueError("Dimension must be divisible by 4 for axial RoPE")
+        self.end_x, self.end_y = end_x, end_y
+        self.dim = dim
+        self.rope_theta = config.rope_theta
+        self.scale = scale
         freqs = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

         flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
@@ -776,6 +780,19 @@ class Sam3PreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, Sam3ViTEmbeddings):
             init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, Sam3ViTRotaryEmbedding):
+            end_x, end_y = module.end_x, module.end_y
+            dim = module.dim
+            freqs = 1.0 / (module.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+            flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
+            x_positions = (flattened_indices % end_x) * module.scale
+            y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * module.scale
+            freqs_x = torch.outer(x_positions, freqs).float()
+            freqs_y = torch.outer(y_positions, freqs).float()
+            inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+            inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())


 @auto_docstring
@@ -1338,6 +1355,8 @@ class Sam3DetrEncoder(Sam3PreTrainedModel):

         self.layers = nn.ModuleList([Sam3DetrEncoderLayer(config) for _ in range(config.num_layers)])

+        self.post_init()
+
     def _prepare_multilevel_features(
         self,
         vision_features: list[torch.Tensor],
@@ -1617,6 +1636,8 @@ class Sam3DetrDecoder(Sam3PreTrainedModel):

         self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=False)

+        self.post_init()
+
     @compile_compatible_method_lru_cache(maxsize=1)
     def _get_coords(
         self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device
@@ -1987,6 +2008,8 @@ class Sam3MaskDecoder(Sam3PreTrainedModel):
         self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size)
         self.prompt_cross_attn_dropout = nn.Dropout(config.dropout)

+        self.post_init()
+
     @check_model_inputs
     def forward(
         self,