transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/sam3/modeling_sam3.py

@@ -14,9 +14,8 @@
 # limitations under the License.
 
 
-import collections.abc
 import math
-from collections.abc import Callable
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from typing import Optional, Union
 
@@ -40,7 +39,7 @@ from ...modeling_outputs import (
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
-from ...utils import auto_docstring
+from ...utils import auto_docstring, logging
 from ...utils.generic import TransformersKwargs, check_model_inputs
 from ..auto import AutoModel
 from .configuration_sam3 import (
@@ -54,6 +53,9 @@ from .configuration_sam3 import (
 )
 
 
+logger = logging.get_logger(__name__)
+
+
 @dataclass
 @auto_docstring
 class Sam3VisionEncoderOutput(ModelOutput):
@@ -123,8 +125,8 @@ class Sam3DETRDecoderOutput(ModelOutput):
         Decoder hidden states from all layers.
     reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`):
         Predicted reference boxes from all decoder layers in (cx, cy, w, h) format.
-    presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size
-        Presence logits from all decoder layers
+    presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size, 1)`):
+        Presence logits from all decoder layers indicating object presence confidence.
     hidden_states (`tuple[torch.FloatTensor]`, *optional*):
         Tuple of hidden states from all decoder layers.
     attentions (`tuple[torch.FloatTensor]`, *optional*):
@@ -133,7 +135,7 @@ class Sam3DETRDecoderOutput(ModelOutput):
 
     intermediate_hidden_states: torch.FloatTensor = None
     reference_boxes: torch.FloatTensor = None
-    presence_logits:
+    presence_logits: torch.FloatTensor = None
     hidden_states: Optional[tuple[torch.FloatTensor]] = None
     attentions: Optional[tuple[torch.FloatTensor]] = None
 
@@ -372,6 +374,19 @@ class Sam3Attention(nn.Module):
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
+        if (
+            "flash" in self.config._attn_implementation
+            and attention_mask is not None
+            and attention_mask.dtype != torch.bool
+        ):
+            # Relative position bias tensors are represented as float masks and are incompatible with Flash Attention
+            # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
+            attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
+            logger.warning_once(
+                "Sam3Attention: falling back to SDPA for relative-position cross-attention because "
+                "Flash Attention does not support additive bias masks."
+            )
+
         attn_output, attn_weights = attention_interface(
             self,
             query,
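
A minimal standalone sketch of the fallback rule this hunk adds; the helper name `pick_attention_backend` is illustrative and not part of the library. The point is that a non-boolean (additive float) attention mask forces the call onto SDPA even when a Flash Attention backend was requested:

```python
import torch


def pick_attention_backend(requested: str, attention_mask: torch.Tensor | None) -> str:
    """Return the backend to use for a single attention call (illustrative sketch)."""
    if "flash" in requested and attention_mask is not None and attention_mask.dtype != torch.bool:
        # Additive float bias masks (e.g. relative-position biases) are not supported by Flash Attention.
        return "sdpa"
    return requested


assert pick_attention_backend("flash_attention_2", torch.zeros(1, 1, 4, 4)) == "sdpa"
assert pick_attention_backend("flash_attention_2", torch.ones(1, 1, 4, 4, dtype=torch.bool)) == "flash_attention_2"
assert pick_attention_backend("sdpa", None) == "sdpa"
```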
@@ -402,6 +417,10 @@ class Sam3ViTRotaryEmbedding(nn.Module):
         # Ensure even dimension for proper axial splitting
         if dim % 4 != 0:
             raise ValueError("Dimension must be divisible by 4 for axial RoPE")
+        self.end_x, self.end_y = end_x, end_y
+        self.dim = dim
+        self.rope_theta = config.rope_theta
+        self.scale = scale
         freqs = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
 
         flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
@@ -531,8 +550,8 @@ class Sam3ViTPatchEmbeddings(nn.Module):
         image_size, patch_size = config.pretrain_image_size, config.patch_size
         num_channels, hidden_size = config.num_channels, config.hidden_size
 
-        image_size = image_size if isinstance(image_size,
-        patch_size = patch_size if isinstance(patch_size,
+        image_size = image_size if isinstance(image_size, Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, Iterable) else (patch_size, patch_size)
         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         self.image_size = image_size
         self.patch_size = patch_size
@@ -542,7 +561,7 @@ class Sam3ViTPatchEmbeddings(nn.Module):
         self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False)
 
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
-        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).flatten(2).transpose(1, 2)
         return embeddings
 
 
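
A small sketch of the input-cast pattern added to `forward` above, assuming only standard PyTorch: pixel values produced in float32 are cast to the projection weight's dtype so a model loaded in another precision does not hit a dtype-mismatch error (float64 stands in for fp16/bf16 here to keep the example CPU-friendly):

```python
import torch
from torch import nn

# Stand-in for the patch projection; float64 plays the role of a non-default weight dtype.
proj = nn.Conv2d(3, 8, kernel_size=4, stride=4, bias=False).to(torch.float64)
pixel_values = torch.rand(1, 3, 16, 16)  # float32, as an image processor would produce

# Casting the input to the weight dtype avoids the usual "expected scalar type ..." error.
embeddings = proj(pixel_values.to(proj.weight.dtype)).flatten(2).transpose(1, 2)
print(embeddings.shape, embeddings.dtype)  # torch.Size([1, 16, 8]) torch.float64
```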
@@ -761,6 +780,19 @@ class Sam3PreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, Sam3ViTEmbeddings):
             init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, Sam3ViTRotaryEmbedding):
+            end_x, end_y = module.end_x, module.end_y
+            dim = module.dim
+            freqs = 1.0 / (module.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+            flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
+            x_positions = (flattened_indices % end_x) * module.scale
+            y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * module.scale
+            freqs_x = torch.outer(x_positions, freqs).float()
+            freqs_y = torch.outer(y_positions, freqs).float()
+            inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+            inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())
 
 
 @auto_docstring
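
For reference, a self-contained sketch of the axial RoPE table construction that the new `_init_weights` branch rebuilds, with toy sizes assumed (a 3x2 patch grid, `dim=8`, `theta=10000`, `scale=1.0`); it shows how x/y positions are recovered from the flattened patch index before the cos/sin tables are filled:

```python
import torch

end_x, end_y, dim, theta, scale = 3, 2, 8, 10000.0, 1.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: dim // 4].float() / dim))
idx = torch.arange(end_x * end_y)
x_pos = (idx % end_x) * scale                                 # column of each patch
y_pos = torch.div(idx, end_x, rounding_mode="floor") * scale  # row of each patch
inv_freq = torch.cat([torch.outer(x_pos, freqs), torch.outer(y_pos, freqs)], dim=-1)
inv_freq = inv_freq.repeat_interleave(2, dim=-1)              # pair up channels for cos/sin rotation
print(inv_freq.cos().shape, inv_freq.sin().shape)  # torch.Size([6, 8]) torch.Size([6, 8])
```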
@@ -938,6 +970,7 @@ class Sam3FPNLayer(nn.Module):
         self.proj2 = nn.Conv2d(in_channels=fpn_dim, out_channels=fpn_dim, kernel_size=3, padding=1)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = hidden_states.to(self.proj1.weight.dtype)
         for layer in self.scale_layers:
             hidden_states = layer(hidden_states)
 
@@ -1253,7 +1286,7 @@ class Sam3DetrEncoderLayer(nn.Module):
|
|
|
1253
1286
|
vision_feats: Tensor,
|
|
1254
1287
|
prompt_feats: Tensor,
|
|
1255
1288
|
vision_pos_encoding: Tensor,
|
|
1256
|
-
|
|
1289
|
+
prompt_cross_attn_mask: Optional[Tensor] = None,
|
|
1257
1290
|
**kwargs: Unpack[TransformersKwargs],
|
|
1258
1291
|
):
|
|
1259
1292
|
"""
|
|
@@ -1263,7 +1296,7 @@ class Sam3DetrEncoderLayer(nn.Module):
|
|
|
1263
1296
|
vision_feats: Vision features [batch_size, vision_len, hidden_size] (main hidden states)
|
|
1264
1297
|
prompt_feats: Text prompt features [batch_size, text_len, hidden_size]
|
|
1265
1298
|
vision_pos_encoding: Position encoding for vision [batch_size, vision_len, hidden_size]
|
|
1266
|
-
|
|
1299
|
+
prompt_cross_attn_mask: Cross-attention mask for prompt features
|
|
1267
1300
|
|
|
1268
1301
|
Returns:
|
|
1269
1302
|
Updated vision features [batch_size, vision_len, hidden_size]
|
|
@@ -1284,15 +1317,6 @@ class Sam3DetrEncoderLayer(nn.Module):
|
|
|
1284
1317
|
residual = hidden_states
|
|
1285
1318
|
hidden_states = self.layer_norm2(hidden_states)
|
|
1286
1319
|
|
|
1287
|
-
prompt_cross_attn_mask = None
|
|
1288
|
-
if prompt_mask is not None:
|
|
1289
|
-
prompt_cross_attn_mask = create_bidirectional_mask(
|
|
1290
|
-
config=self.config,
|
|
1291
|
-
input_embeds=hidden_states,
|
|
1292
|
-
attention_mask=prompt_mask,
|
|
1293
|
-
encoder_hidden_states=prompt_feats,
|
|
1294
|
-
)
|
|
1295
|
-
|
|
1296
1320
|
hidden_states, _ = self.cross_attn(
|
|
1297
1321
|
query=hidden_states,
|
|
1298
1322
|
key=prompt_feats,
|
|
@@ -1331,6 +1355,8 @@ class Sam3DetrEncoder(Sam3PreTrainedModel):
|
|
|
1331
1355
|
|
|
1332
1356
|
self.layers = nn.ModuleList([Sam3DetrEncoderLayer(config) for _ in range(config.num_layers)])
|
|
1333
1357
|
|
|
1358
|
+
self.post_init()
|
|
1359
|
+
|
|
1334
1360
|
def _prepare_multilevel_features(
|
|
1335
1361
|
self,
|
|
1336
1362
|
vision_features: list[torch.Tensor],
|
|
@@ -1412,13 +1438,22 @@ class Sam3DetrEncoder(Sam3PreTrainedModel):
|
|
|
1412
1438
|
spatial_shapes,
|
|
1413
1439
|
) = self._prepare_multilevel_features(vision_features, vision_pos_embeds)
|
|
1414
1440
|
|
|
1441
|
+
prompt_cross_attn_mask = None
|
|
1442
|
+
if text_mask is not None:
|
|
1443
|
+
prompt_cross_attn_mask = create_bidirectional_mask(
|
|
1444
|
+
config=self.config,
|
|
1445
|
+
input_embeds=features_flattened,
|
|
1446
|
+
attention_mask=text_mask,
|
|
1447
|
+
encoder_hidden_states=text_features,
|
|
1448
|
+
)
|
|
1449
|
+
|
|
1415
1450
|
hidden_states = features_flattened
|
|
1416
1451
|
for layer in self.layers:
|
|
1417
1452
|
hidden_states = layer(
|
|
1418
1453
|
hidden_states,
|
|
1419
1454
|
prompt_feats=text_features,
|
|
1420
1455
|
vision_pos_encoding=pos_embeds_flattened,
|
|
1421
|
-
|
|
1456
|
+
prompt_cross_attn_mask=prompt_cross_attn_mask,
|
|
1422
1457
|
**kwargs,
|
|
1423
1458
|
)
|
|
1424
1459
|
return Sam3DETREncoderOutput(
|
|
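Note: the encoder now builds `prompt_cross_attn_mask` once with `create_bidirectional_mask` before the layer loop and hands it to every `Sam3DetrEncoderLayer`, instead of rebuilding it inside each layer; the mask depends only on the text padding mask, so it is identical across layers. A library-independent sketch of what such an additive padding mask represents (helper name and shapes are illustrative, not the transformers API):

import torch

def padding_cross_attention_mask(text_mask: torch.Tensor, num_queries: int) -> torch.Tensor:
    # text_mask: (batch, text_len) with 1 for real tokens and 0 for padding
    # returns an additive mask of shape (batch, 1, num_queries, text_len)
    neg_inf = torch.finfo(torch.float32).min
    mask = (1.0 - text_mask.float()) * neg_inf
    return mask[:, None, None, :].expand(-1, 1, num_queries, -1)

mask = padding_cross_attention_mask(torch.tensor([[1, 1, 0]]), num_queries=4)  # (1, 1, 4, 3)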
@@ -1484,31 +1519,27 @@ class Sam3DetrDecoderLayer(nn.Module):
         text_features: torch.Tensor,
         vision_features: torch.Tensor,
         vision_pos_encoding: torch.Tensor,
-
+        text_cross_attn_mask: Optional[torch.Tensor] = None,
         vision_cross_attn_mask: Optional[torch.Tensor] = None,
-        presence_token: Optional[torch.Tensor] = None,
         **kwargs: Unpack[TransformersKwargs],
-    ) ->
+    ) -> torch.Tensor:
         """
         Forward pass for decoder layer.

         Args:
-            hidden_states: Query features [batch_size, num_queries, hidden_size]
+            hidden_states: Query features [batch_size, num_queries + 1, hidden_size] (includes presence token at position 0)
            query_pos: Query position embeddings [batch_size, num_queries, hidden_size]
            text_features: Text features [batch_size, seq_len, hidden_size]
            vision_features: Vision features [batch_size, height*width, hidden_size]
            vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
-
-            vision_cross_attn_mask: Vision cross-attention mask
-            presence_token: Optional presence token [batch_size, 1, hidden_size]
+            text_cross_attn_mask: Text cross-attention mask
+            vision_cross_attn_mask: Vision cross-attention mask, already expanded for presence token

         Returns:
-
+            Updated hidden states (including presence token at position 0)
         """
-        #
-
-        hidden_states = torch.cat([presence_token, hidden_states], dim=1)
-        query_pos = torch.cat([torch.zeros_like(presence_token), query_pos], dim=1)
+        # Prepend zeros to query_pos for presence token
+        query_pos = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0)

         # Self-attention with query position encoding
         residual = hidden_states
@@ -1527,15 +1558,6 @@ class Sam3DetrDecoderLayer(nn.Module):
         residual = hidden_states
         query_with_pos = hidden_states + query_pos

-        text_cross_attn_mask = None
-        if text_mask is not None:
-            text_cross_attn_mask = create_bidirectional_mask(
-                config=self.config,
-                input_embeds=hidden_states,
-                attention_mask=text_mask,
-                encoder_hidden_states=text_features,
-            )
-
         attn_output, _ = self.text_cross_attn(
             query=query_with_pos,
             key=text_features,
@@ -1546,20 +1568,6 @@ class Sam3DetrDecoderLayer(nn.Module):
         hidden_states = residual + self.text_cross_attn_dropout(attn_output)
         hidden_states = self.text_cross_attn_layer_norm(hidden_states)

-        # Expand vision cross-attention mask for presence token if needed
-        combined_vision_mask = vision_cross_attn_mask
-        if presence_token is not None and combined_vision_mask is not None:
-            batch_size, num_heads = combined_vision_mask.shape[:2]
-            presence_mask = torch.zeros(
-                batch_size,
-                num_heads,
-                1,
-                combined_vision_mask.shape[-1],
-                device=combined_vision_mask.device,
-                dtype=combined_vision_mask.dtype,
-            )
-            combined_vision_mask = torch.cat([presence_mask, combined_vision_mask], dim=2)
-
         # Vision cross-attention: queries attend to vision features (with RPB)
         residual = hidden_states
         query_with_pos = hidden_states + query_pos
@@ -1568,7 +1576,7 @@ class Sam3DetrDecoderLayer(nn.Module):
             query=query_with_pos,
             key=key_with_pos,
             value=vision_features,
-            attention_mask=
+            attention_mask=vision_cross_attn_mask,
             **kwargs,
         )
         hidden_states = residual + self.vision_cross_attn_dropout(attn_output)
@@ -1580,13 +1588,7 @@ class Sam3DetrDecoderLayer(nn.Module):
         hidden_states = residual + self.mlp_dropout(hidden_states)
         hidden_states = self.mlp_layer_norm(hidden_states)

-
-        presence_token_out = None
-        if presence_token is not None:
-            presence_token_out = hidden_states[:, :1]
-            hidden_states = hidden_states[:, 1:]
-
-        return hidden_states, presence_token_out
+        return hidden_states


 class Sam3DetrDecoder(Sam3PreTrainedModel):
@@ -1634,6 +1636,8 @@ class Sam3DetrDecoder(Sam3PreTrainedModel):

         self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=False)

+        self.post_init()
+
     @compile_compatible_method_lru_cache(maxsize=1)
     def _get_coords(
         self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device
@@ -1715,11 +1719,23 @@ class Sam3DetrDecoder(Sam3PreTrainedModel):
         """
         batch_size = vision_features.shape[0]

-
+        query_embeds = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
         reference_boxes = self.reference_points.weight.unsqueeze(0).expand(batch_size, -1, -1)
         reference_boxes = reference_boxes.sigmoid()
         presence_token = self.presence_token.weight.unsqueeze(0).expand(batch_size, -1, -1)

+        # Concatenate presence token with query embeddings
+        hidden_states = torch.cat([presence_token, query_embeds], dim=1)
+
+        text_cross_attn_mask = None
+        if text_mask is not None:
+            text_cross_attn_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=hidden_states,
+                attention_mask=text_mask,
+                encoder_hidden_states=text_features,
+            )
+
         intermediate_outputs = []
         intermediate_boxes = [reference_boxes]
         intermediate_presence_logits = []
@@ -1734,43 +1750,45 @@ class Sam3DetrDecoder(Sam3PreTrainedModel):
             vision_cross_attn_mask = None
             if spatial_shapes is not None and spatial_shapes.shape[0] == 1:
                 spatial_shape = (spatial_shapes[0, 0], spatial_shapes[0, 1])
-
+                rpb_matrix = self._get_rpb_matrix(reference_boxes, spatial_shape)
+                # Prepend zeros row for presence token (it attends to all vision tokens equally)
+                vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0)

-            hidden_states
+            hidden_states = layer(
                 hidden_states,
                 query_pos=query_pos,
                 text_features=text_features,
                 vision_features=vision_features,
                 vision_pos_encoding=vision_pos_encoding,
-
+                text_cross_attn_mask=text_cross_attn_mask,
                 vision_cross_attn_mask=vision_cross_attn_mask,
-                presence_token=presence_token,
                 **kwargs,
             )

+            # Extract query hidden states (without presence token) for box refinement
+            query_hidden_states = hidden_states[:, 1:]
+
             # Box refinement: predict delta and update reference boxes
             reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes)
-            delta_boxes = self.box_head(self.output_layer_norm(
+            delta_boxes = self.box_head(self.output_layer_norm(query_hidden_states))
             new_reference_boxes = (delta_boxes + reference_boxes_before_sigmoid).sigmoid()
             reference_boxes = new_reference_boxes.detach()

-            intermediate_outputs.append(self.output_layer_norm(
+            intermediate_outputs.append(self.output_layer_norm(query_hidden_states))
             intermediate_boxes.append(new_reference_boxes)

             # Process presence token
-
-
-
-
-
-
+            presence_hidden = hidden_states[:, :1]
+            presence_logits = self.presence_head(self.presence_layer_norm(presence_hidden)).squeeze(-1)
+            presence_logits = presence_logits.clamp(
+                min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
+            )
+            intermediate_presence_logits.append(presence_logits)

         # Stack outputs from all layers
         intermediate_outputs = torch.stack(intermediate_outputs)
         intermediate_boxes = torch.stack(intermediate_boxes[:-1])
-        intermediate_presence_logits = (
-            torch.stack(intermediate_presence_logits) if intermediate_presence_logits else None
-        )
+        intermediate_presence_logits = torch.stack(intermediate_presence_logits)

         return Sam3DETRDecoderOutput(
             intermediate_hidden_states=intermediate_outputs,
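Note: in the refactored decoder the presence token stays concatenated at position 0 of the hidden states for the whole layer stack; `F.pad(..., (0, 0, 1, 0))` prepends a zero row to the query position embeddings (and a zero row to the RPB attention bias), and plain slicing separates the token from the object queries afterwards. A toy-shape sketch of that flow:

import torch
import torch.nn.functional as F

batch, num_queries, hidden = 2, 5, 8
presence_token = torch.zeros(batch, 1, hidden)
query_embeds = torch.randn(batch, num_queries, hidden)
query_pos = torch.randn(batch, num_queries, hidden)

hidden_states = torch.cat([presence_token, query_embeds], dim=1)  # (2, 6, 8)
query_pos = F.pad(query_pos, (0, 0, 1, 0), value=0)               # (2, 6, 8), zero row first

presence_hidden = hidden_states[:, :1]      # (2, 1, 8) -> presence head
query_hidden_states = hidden_states[:, 1:]  # (2, 5, 8) -> box refinement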
@@ -1990,6 +2008,8 @@ class Sam3MaskDecoder(Sam3PreTrainedModel):
         self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size)
         self.prompt_cross_attn_dropout = nn.Dropout(config.dropout)

+        self.post_init()
+
     @check_model_inputs
     def forward(
         self,
@@ -107,7 +107,12 @@ class Sam3TrackerFeedForward(nn.Module):
         return hidden_states


-@auto_docstring
+@auto_docstring(
+    custom_intro="""
+    Segment Anything Model 3 (SAM 3) for generating segmentation masks, given an input image and
+    input points and labels, boxes, or masks.
+    """
+)
 class Sam3TrackerPreTrainedModel(PreTrainedModel):
     config_class = Sam3TrackerConfig
     base_model_prefix = "sam3_tracker"
@@ -123,6 +128,8 @@ class Sam3TrackerPreTrainedModel(PreTrainedModel):
         if isinstance(module, Sam3TrackerModel):
             if module.no_memory_embedding is not None:
                 init.zeros_(module.no_memory_embedding)
+        elif isinstance(module, Sam3TrackerPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)


 class Sam3TrackerPositionalEmbedding(nn.Module):
@@ -136,7 +136,12 @@ class Sam3TrackerFeedForward(Sam2FeedForward):
     pass


-@auto_docstring
+@auto_docstring(
+    custom_intro="""
+    Segment Anything Model 3 (SAM 3) for generating segmentation masks, given an input image and
+    input points and labels, boxes, or masks.
+    """
+)
 class Sam3TrackerPreTrainedModel(Sam2PreTrainedModel):
     @torch.no_grad()
     def _init_weights(self, module):
@@ -144,6 +149,8 @@ class Sam3TrackerPreTrainedModel(Sam2PreTrainedModel):
         if isinstance(module, Sam3TrackerModel):
             if module.no_memory_embedding is not None:
                 init.zeros_(module.no_memory_embedding)
+        elif isinstance(module, Sam3TrackerPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)


 class Sam3TrackerPositionalEmbedding(Sam2PositionalEmbedding):
@@ -397,5 +397,30 @@ class Sam3TrackerVideoConfig(PreTrainedConfig):

         super().__init__(**kwargs)

+    @property
+    def image_size(self):
+        """Image size for the tracker video model."""
+        return self.vision_config.image_size
+
+    @image_size.setter
+    def image_size(self, value):
+        """Set the image size and propagate to sub-configs. Calculates feature sizes based on patch_size."""
+        self.prompt_encoder_config.image_size = value
+        self.vision_config.image_size = value
+
+        patch_size = self.vision_config.backbone_config.patch_size
+        self.vision_config.backbone_feature_sizes = [
+            [4 * value // patch_size, 4 * value // patch_size],
+            [2 * value // patch_size, 2 * value // patch_size],
+            [value // patch_size, value // patch_size],
+        ]
+        self.memory_attention_rope_feat_sizes = [
+            value // patch_size,
+            value // patch_size,
+        ]
+
+        # keep the image_size in the __dict__ to save the value in the config file (backward compatibility)
+        self.__dict__["image_size"] = value
+

 __all__ = ["Sam3TrackerVideoMaskDecoderConfig", "Sam3TrackerVideoPromptEncoderConfig", "Sam3TrackerVideoConfig"]
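Note: the new `image_size` setter recomputes every derived spatial size from `value // patch_size`, so a single assignment keeps the vision backbone, prompt encoder, and memory-attention RoPE feature sizes consistent. The arithmetic, shown standalone with a hypothetical `patch_size` of 16:

patch_size = 16
value = 1024  # e.g. config.image_size = 1024
backbone_feature_sizes = [
    [4 * value // patch_size, 4 * value // patch_size],  # [256, 256]
    [2 * value // patch_size, 2 * value // patch_size],  # [128, 128]
    [value // patch_size, value // patch_size],          # [64, 64]
]
memory_attention_rope_feat_sizes = [value // patch_size, value // patch_size]  # [64, 64]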
@@ -213,7 +213,7 @@ class Sam3TrackerVideoInferenceSession:
         device_inputs = {}
         for key, value in inputs.items():
             if isinstance(value, torch.Tensor):
-                device_inputs[key] = value.to(self.inference_device, non_blocking=
+                device_inputs[key] = value.to(self.inference_device, non_blocking=False)
             else:
                 device_inputs[key] = value
         self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs
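Note: switching to `non_blocking=False` makes the host-to-device copy synchronous, so the tensor stored in the inference session is guaranteed to be fully materialized on the device before it is reused (`non_blocking=True` may return before a pinned-memory copy has completed). Minimal illustration:

import torch

x = torch.randn(4, 4)
if torch.cuda.is_available():
    y = x.to("cuda", non_blocking=False)  # synchronous: safe to stash and reuse y immediately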
@@ -692,6 +692,12 @@ class Sam3TrackerVideoPreTrainedModel(PreTrainedModel):
         if isinstance(module, Sam3TrackerVideoMemoryFuserCXBlock):
             if module.scale is not None:
                 init.zeros_(module.scale)
+        elif isinstance(module, Sam3TrackerVideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())
+        elif isinstance(module, Sam3TrackerVideoPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)


 class Sam3TrackerVideoVisionRotaryEmbedding(nn.Module):
@@ -702,24 +708,17 @@ class Sam3TrackerVideoVisionRotaryEmbedding(nn.Module):

     def __init__(self, config: Sam3TrackerVideoConfig):
         super().__init__()
-        dim = config.memory_attention_hidden_size // (
+        self.dim = config.memory_attention_hidden_size // (
             config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
         )
         # Ensure even dimension for proper axial splitting
-        if dim % 4 != 0:
+        if self.dim % 4 != 0:
             raise ValueError("Dimension must be divisible by 4 for axial RoPE")
-        end_x, end_y = config.memory_attention_rope_feat_sizes
-
+        self.end_x, self.end_y = config.memory_attention_rope_feat_sizes
+        self.memory_attention_rope_theta = config.memory_attention_rope_theta

-        # Generate 2D position indices for axial rotary embedding
-        flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
-        x_positions = flattened_indices % end_x
-        y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
-        freqs_x = torch.outer(x_positions, freqs).float()
-        freqs_y = torch.outer(y_positions, freqs).float()
-        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
-        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
         # directly register the cos and sin embeddings as we have a fixed feature shape
+        inv_freq = self.create_inv_freq()
         self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
         self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)

@@ -728,6 +727,20 @@ class Sam3TrackerVideoVisionRotaryEmbedding(nn.Module):
         # As the feature map size is fixed, we can just return the pre-computed embeddings.
         return self.rope_embeddings_cos, self.rope_embeddings_sin

+    def create_inv_freq(self):
+        freqs = 1.0 / (
+            self.memory_attention_rope_theta ** (torch.arange(0, self.dim, 4)[: (self.dim // 4)].float() / self.dim)
+        )
+        # Generate 2D position indices for axial rotary embedding
+        flattened_indices = torch.arange(self.end_x * self.end_y, dtype=torch.long)
+        x_positions = flattened_indices % self.end_x
+        y_positions = torch.div(flattened_indices, self.end_x, rounding_mode="floor")
+        freqs_x = torch.outer(x_positions, freqs).float()
+        freqs_y = torch.outer(y_positions, freqs).float()
+        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+        return inv_freq
+

 def rotate_pairwise(x):
     """
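Note: `rope_embeddings_cos` and `rope_embeddings_sin` are registered with `persistent=False`, so they are not saved in the checkpoint and must be recomputable after loading; factoring the computation into `create_inv_freq()` gives `_init_weights` a single place to rebuild them. A small standalone check of the non-persistent-buffer behavior this relies on:

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cos_table", torch.ones(4), persistent=False)

print(Toy().state_dict().keys())  # odict_keys([]) -> the buffer is not part of the state_dict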
@@ -1567,8 +1580,6 @@ class Sam3TrackerVideoModel(Sam3TrackerVideoPreTrainedModel):
     input_modalities = ("video", "text")
     _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam3TrackerVideoTwoWayAttentionBlock, index=2)}
     _keys_to_ignore_on_load_unexpected = [r"^detector_model."]
-    _tied_weights_keys = {}
-    _keys_to_ignore_on_load_missing = []
     _checkpoint_conversion_mapping = {
         r"tracker_model.(.+)": r"\1",  # the regex allows to remove the prefix, and add it back in revert mode
         "detector_model.vision_encoder.backbone.": "vision_encoder.backbone.",
@@ -1719,6 +1730,7 @@ class Sam3TrackerVideoModel(Sam3TrackerVideoPreTrainedModel):
         frame: Optional[torch.Tensor] = None,
         reverse: bool = False,
         run_mem_encoder: bool = True,
+        **kwargs,
     ) -> Sam3TrackerVideoSegmentationOutput:
         r"""
         inference_session (`Sam3TrackerVideoInferenceSession`):
@@ -353,6 +353,31 @@ class Sam3TrackerVideoConfig(PreTrainedConfig):

         super().__init__(**kwargs)

+    @property
+    def image_size(self):
+        """Image size for the tracker video model."""
+        return self.vision_config.image_size
+
+    @image_size.setter
+    def image_size(self, value):
+        """Set the image size and propagate to sub-configs. Calculates feature sizes based on patch_size."""
+        self.prompt_encoder_config.image_size = value
+        self.vision_config.image_size = value
+
+        patch_size = self.vision_config.backbone_config.patch_size
+        self.vision_config.backbone_feature_sizes = [
+            [4 * value // patch_size, 4 * value // patch_size],
+            [2 * value // patch_size, 2 * value // patch_size],
+            [value // patch_size, value // patch_size],
+        ]
+        self.memory_attention_rope_feat_sizes = [
+            value // patch_size,
+            value // patch_size,
+        ]
+
+        # keep the image_size in the __dict__ to save the value in the config file (backward compatibility)
+        self.__dict__["image_size"] = value
+

 class Sam3TrackerVideoInferenceCache(Sam2VideoInferenceCache):
     pass
@@ -461,8 +486,6 @@ class Sam3TrackerVideoModel(Sam2VideoModel):
         "tracker_neck.": "vision_encoder.neck.",
     }
     _keys_to_ignore_on_load_unexpected = [r"^detector_model."]
-    _tied_weights_keys = {}
-    _keys_to_ignore_on_load_missing = []

     def __init__(self, config: Sam3TrackerVideoConfig, remove_vision_encoder: bool = False):
         r"""
@@ -96,6 +96,9 @@ class Sam3VideoConfig(PreTrainedConfig):
    >>> # Initializing a SAM3 Video configuration with default detector and tracker
    >>> configuration = Sam3VideoConfig()

+    >>> # Changing image size for custom resolution inference (automatically propagates to all nested configs)
+    >>> configuration.image_size = 560
+
    >>> # Initializing a model from the configuration
    >>> model = Sam3VideoModel(configuration)

@@ -225,5 +228,16 @@ class Sam3VideoConfig(PreTrainedConfig):
         self.high_conf_thresh = high_conf_thresh
         self.high_iou_thresh = high_iou_thresh

+    @property
+    def image_size(self):
+        """Image size for the video model."""
+        return self.detector_config.image_size
+
+    @image_size.setter
+    def image_size(self, value):
+        """Recursively propagate the image size to detector and tracker configs."""
+        self.detector_config.image_size = value
+        self.tracker_config.image_size = value
+

 __all__ = ["Sam3VideoConfig"]
@@ -33,7 +33,7 @@ from .configuration_sam3_video import Sam3VideoConfig


 if is_kernels_available():
-    from
+    from ...integrations.hub_kernels import get_kernel

 logger = logging.get_logger(__name__)

@@ -505,8 +505,6 @@ class Sam3VideoPreTrainedModel(PreTrainedModel):

 @auto_docstring
 class Sam3VideoModel(Sam3VideoPreTrainedModel):
-    all_tied_weights_keys = {}
-
     def __init__(self, config: Sam3VideoConfig):
         super().__init__(config)
         self.config = config
@@ -542,6 +540,8 @@ class Sam3VideoModel(Sam3VideoPreTrainedModel):

         self.tracker_neck = Sam3VisionNeck(config.detector_config.vision_config)

+        self.post_init()
+
     def get_vision_features_for_tracker(self, vision_embeds: torch.Tensor):
         hidden_states = vision_embeds.last_hidden_state
         batch_size = hidden_states.shape[0]
@@ -1697,6 +1697,7 @@ class Sam3VideoModel(Sam3VideoPreTrainedModel):
         frame_idx: Optional[int] = None,
         frame: Optional[torch.Tensor] = None,
         reverse: bool = False,
+        **kwargs,
     ):
         r"""
         inference_session (`Sam3VideoInferenceSession`):