transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
--- a/transformers/models/sam3/modeling_sam3.py
+++ b/transformers/models/sam3/modeling_sam3.py
@@ -14,9 +14,8 @@
 # limitations under the License.
 
 
-import collections.abc
 import math
-from collections.abc import Callable
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from typing import Optional, Union
 
@@ -40,7 +39,7 @@ from ...modeling_outputs import (
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
-from ...utils import auto_docstring
+from ...utils import auto_docstring, logging
 from ...utils.generic import TransformersKwargs, check_model_inputs
 from ..auto import AutoModel
 from .configuration_sam3 import (
@@ -54,6 +53,9 @@ from .configuration_sam3 import (
 )
 
 
+logger = logging.get_logger(__name__)
+
+
 @dataclass
 @auto_docstring
 class Sam3VisionEncoderOutput(ModelOutput):
@@ -123,8 +125,8 @@ class Sam3DETRDecoderOutput(ModelOutput):
             Decoder hidden states from all layers.
         reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`):
             Predicted reference boxes from all decoder layers in (cx, cy, w, h) format.
-        presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size
-            Presence logits from all decoder layers
+        presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size, 1)`):
+            Presence logits from all decoder layers indicating object presence confidence.
         hidden_states (`tuple[torch.FloatTensor]`, *optional*):
             Tuple of hidden states from all decoder layers.
         attentions (`tuple[torch.FloatTensor]`, *optional*):
@@ -133,7 +135,7 @@ class Sam3DETRDecoderOutput(ModelOutput):
 
     intermediate_hidden_states: torch.FloatTensor = None
     reference_boxes: torch.FloatTensor = None
-    presence_logits:
+    presence_logits: torch.FloatTensor = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
@@ -372,6 +374,19 @@ class Sam3Attention(nn.Module):
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
+            if (
+                "flash" in self.config._attn_implementation
+                and attention_mask is not None
+                and attention_mask.dtype != torch.bool
+            ):
+                # Relative position bias tensors are represented as float masks and are incompatible with Flash Attention
+                # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
+                attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
+                logger.warning_once(
+                    "Sam3Attention: falling back to SDPA for relative-position cross-attention because "
+                    "Flash Attention does not support additive bias masks."
+                )
+
         attn_output, attn_weights = attention_interface(
             self,
             query,
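Background for the fallback above: FlashAttention kernels accept boolean padding masks but not additive float biases, while SDPA accepts a float `attn_mask` that is added to the attention scores. A minimal sketch of the same dispatch rule in plain PyTorch (the names below are illustrative, not transformers internals):

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, dim = 1, 2, 4, 6, 8
query = torch.randn(batch, heads, q_len, dim)
key = torch.randn(batch, heads, kv_len, dim)
value = torch.randn(batch, heads, kv_len, dim)

# A relative position bias is an additive float mask over attention scores.
bias = torch.randn(batch, heads, q_len, kv_len)

# Mirror of the hunk's rule: non-bool masks cannot be expressed as flash
# attention's boolean padding masks, so route this call through SDPA.
if bias is not None and bias.dtype != torch.bool:
    out = F.scaled_dot_product_attention(query, key, value, attn_mask=bias)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```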
@@ -531,8 +546,8 @@ class Sam3ViTPatchEmbeddings(nn.Module):
         image_size, patch_size = config.pretrain_image_size, config.patch_size
         num_channels, hidden_size = config.num_channels, config.hidden_size
 
-        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
-        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        image_size = image_size if isinstance(image_size, Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, Iterable) else (patch_size, patch_size)
         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         self.image_size = image_size
         self.patch_size = patch_size
@@ -542,7 +557,7 @@ class Sam3ViTPatchEmbeddings(nn.Module):
         self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False)
 
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
-        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).flatten(2).transpose(1, 2)
         return embeddings
 
 
@@ -938,6 +953,7 @@ class Sam3FPNLayer(nn.Module):
         self.proj2 = nn.Conv2d(in_channels=fpn_dim, out_channels=fpn_dim, kernel_size=3, padding=1)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = hidden_states.to(self.proj1.weight.dtype)
         for layer in self.scale_layers:
             hidden_states = layer(hidden_states)
 
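Two recurring patterns appear in these hunks: normalizing an int-or-pair size argument via `Iterable`, and casting inputs to a layer's parameter dtype before the projection. A runnable sketch assuming only plain PyTorch (`to_2tuple` is a hypothetical helper, and float64 stands in for the fp16/bf16 weights used in practice so the demo runs on CPU):

```python
from collections.abc import Iterable

import torch
from torch import nn

def to_2tuple(size):
    # mirrors `x if isinstance(x, Iterable) else (x, x)` from the hunk
    return tuple(size) if isinstance(size, Iterable) else (size, size)

print(to_2tuple(224))         # (224, 224)
print(to_2tuple((224, 160)))  # (224, 160)

# Casting pixel_values to the projection's own dtype avoids dtype-mismatch
# errors when model weights are kept in a different precision than the input.
projection = nn.Conv2d(3, 16, kernel_size=16, stride=16, bias=False).double()
pixel_values = torch.randn(1, 3, 224, 224)  # float32 input
embeddings = projection(pixel_values.to(projection.weight.dtype)).flatten(2).transpose(1, 2)
print(embeddings.dtype, embeddings.shape)  # torch.float64 torch.Size([1, 196, 16])
```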
@@ -1253,7 +1269,7 @@ class Sam3DetrEncoderLayer(nn.Module):
         vision_feats: Tensor,
         prompt_feats: Tensor,
         vision_pos_encoding: Tensor,
-        prompt_mask: Optional[Tensor] = None,
+        prompt_cross_attn_mask: Optional[Tensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ):
         """
@@ -1263,7 +1279,7 @@ class Sam3DetrEncoderLayer(nn.Module):
             vision_feats: Vision features [batch_size, vision_len, hidden_size] (main hidden states)
             prompt_feats: Text prompt features [batch_size, text_len, hidden_size]
             vision_pos_encoding: Position encoding for vision [batch_size, vision_len, hidden_size]
-
+            prompt_cross_attn_mask: Cross-attention mask for prompt features
 
         Returns:
             Updated vision features [batch_size, vision_len, hidden_size]
@@ -1284,15 +1300,6 @@ class Sam3DetrEncoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.layer_norm2(hidden_states)
 
-        prompt_cross_attn_mask = None
-        if prompt_mask is not None:
-            prompt_cross_attn_mask = create_bidirectional_mask(
-                config=self.config,
-                input_embeds=hidden_states,
-                attention_mask=prompt_mask,
-                encoder_hidden_states=prompt_feats,
-            )
-
         hidden_states, _ = self.cross_attn(
             query=hidden_states,
             key=prompt_feats,
@@ -1412,13 +1419,22 @@ class Sam3DetrEncoder(Sam3PreTrainedModel):
             spatial_shapes,
         ) = self._prepare_multilevel_features(vision_features, vision_pos_embeds)
 
+        prompt_cross_attn_mask = None
+        if text_mask is not None:
+            prompt_cross_attn_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=features_flattened,
+                attention_mask=text_mask,
+                encoder_hidden_states=text_features,
+            )
+
         hidden_states = features_flattened
         for layer in self.layers:
             hidden_states = layer(
                 hidden_states,
                 prompt_feats=text_features,
                 vision_pos_encoding=pos_embeds_flattened,
-                prompt_mask=text_mask,
+                prompt_cross_attn_mask=prompt_cross_attn_mask,
                 **kwargs,
             )
         return Sam3DETREncoderOutput(
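Taken together, these hunks hoist the `create_bidirectional_mask` call out of the per-layer forward: the encoder now builds the prompt cross-attention mask once and hands the same tensor to every layer. A generic sketch of that refactor (all names here are stand-ins, not transformers APIs):

```python
import torch

def make_padding_mask(lengths: torch.Tensor, kv_len: int) -> torch.Tensor:
    # True where a key position is real, False where it is padding
    return torch.arange(kv_len)[None, :] < lengths[:, None]

def layer_forward(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    return hidden  # stand-in encoder layer consuming a precomputed mask

lengths = torch.tensor([3, 5])
mask = make_padding_mask(lengths, kv_len=5)  # built once, before the loop

hidden = torch.randn(2, 4, 8)
for _ in range(6):
    hidden = layer_forward(hidden, mask)  # every layer reuses the same mask
```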
@@ -1484,31 +1500,27 @@ class Sam3DetrDecoderLayer(nn.Module):
         text_features: torch.Tensor,
         vision_features: torch.Tensor,
         vision_pos_encoding: torch.Tensor,
-        text_mask: Optional[torch.Tensor] = None,
+        text_cross_attn_mask: Optional[torch.Tensor] = None,
         vision_cross_attn_mask: Optional[torch.Tensor] = None,
-        presence_token: Optional[torch.Tensor] = None,
         **kwargs: Unpack[TransformersKwargs],
-    ) ->
+    ) -> torch.Tensor:
         """
         Forward pass for decoder layer.
 
         Args:
-            hidden_states: Query features [batch_size, num_queries, hidden_size]
+            hidden_states: Query features [batch_size, num_queries + 1, hidden_size] (includes presence token at position 0)
             query_pos: Query position embeddings [batch_size, num_queries, hidden_size]
             text_features: Text features [batch_size, seq_len, hidden_size]
             vision_features: Vision features [batch_size, height*width, hidden_size]
             vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
-
-            vision_cross_attn_mask: Vision cross-attention mask
-            presence_token: Optional presence token [batch_size, 1, hidden_size]
+            text_cross_attn_mask: Text cross-attention mask
+            vision_cross_attn_mask: Vision cross-attention mask, already expanded for presence token
 
         Returns:
-
+            Updated hidden states (including presence token at position 0)
         """
-        #
-
-        hidden_states = torch.cat([presence_token, hidden_states], dim=1)
-        query_pos = torch.cat([torch.zeros_like(presence_token), query_pos], dim=1)
+        # Prepend zeros to query_pos for presence token
+        query_pos = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0)
 
         # Self-attention with query position encoding
         residual = hidden_states
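The `F.pad(query_pos, (0, 0, 1, 0))` call pads the last dimension by `(0, 0)` and the second-to-last (sequence) dimension by one leading slot, which is exactly what the removed `torch.cat` with a zeros tensor did. A short check of that equivalence:

```python
import torch
import torch.nn.functional as F

query_pos = torch.randn(2, 10, 8)  # [batch, num_queries, hidden]

padded = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0)
concat = torch.cat([torch.zeros(2, 1, 8), query_pos], dim=1)

print(padded.shape)                 # torch.Size([2, 11, 8])
print(torch.equal(padded, concat))  # True
```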
@@ -1527,15 +1539,6 @@ class Sam3DetrDecoderLayer(nn.Module):
         residual = hidden_states
         query_with_pos = hidden_states + query_pos
 
-        text_cross_attn_mask = None
-        if text_mask is not None:
-            text_cross_attn_mask = create_bidirectional_mask(
-                config=self.config,
-                input_embeds=hidden_states,
-                attention_mask=text_mask,
-                encoder_hidden_states=text_features,
-            )
-
         attn_output, _ = self.text_cross_attn(
             query=query_with_pos,
             key=text_features,
@@ -1546,20 +1549,6 @@ class Sam3DetrDecoderLayer(nn.Module):
         hidden_states = residual + self.text_cross_attn_dropout(attn_output)
         hidden_states = self.text_cross_attn_layer_norm(hidden_states)
 
-        # Expand vision cross-attention mask for presence token if needed
-        combined_vision_mask = vision_cross_attn_mask
-        if presence_token is not None and combined_vision_mask is not None:
-            batch_size, num_heads = combined_vision_mask.shape[:2]
-            presence_mask = torch.zeros(
-                batch_size,
-                num_heads,
-                1,
-                combined_vision_mask.shape[-1],
-                device=combined_vision_mask.device,
-                dtype=combined_vision_mask.dtype,
-            )
-            combined_vision_mask = torch.cat([presence_mask, combined_vision_mask], dim=2)
-
         # Vision cross-attention: queries attend to vision features (with RPB)
         residual = hidden_states
         query_with_pos = hidden_states + query_pos
@@ -1568,7 +1557,7 @@ class Sam3DetrDecoderLayer(nn.Module):
             query=query_with_pos,
             key=key_with_pos,
             value=vision_features,
-            attention_mask=combined_vision_mask,
+            attention_mask=vision_cross_attn_mask,
             **kwargs,
         )
         hidden_states = residual + self.vision_cross_attn_dropout(attn_output)
@@ -1580,13 +1569,7 @@ class Sam3DetrDecoderLayer(nn.Module):
         hidden_states = residual + self.mlp_dropout(hidden_states)
         hidden_states = self.mlp_layer_norm(hidden_states)
 
-
-        presence_token_out = None
-        if presence_token is not None:
-            presence_token_out = hidden_states[:, :1]
-            hidden_states = hidden_states[:, 1:]
-
-        return hidden_states, presence_token_out
+        return hidden_states
 
 
 class Sam3DetrDecoder(Sam3PreTrainedModel):
@@ -1715,11 +1698,23 @@ class Sam3DetrDecoder(Sam3PreTrainedModel):
         """
         batch_size = vision_features.shape[0]
 
-
+        query_embeds = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
         reference_boxes = self.reference_points.weight.unsqueeze(0).expand(batch_size, -1, -1)
         reference_boxes = reference_boxes.sigmoid()
         presence_token = self.presence_token.weight.unsqueeze(0).expand(batch_size, -1, -1)
 
+        # Concatenate presence token with query embeddings
+        hidden_states = torch.cat([presence_token, query_embeds], dim=1)
+
+        text_cross_attn_mask = None
+        if text_mask is not None:
+            text_cross_attn_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=hidden_states,
+                attention_mask=text_mask,
+                encoder_hidden_states=text_features,
+            )
+
         intermediate_outputs = []
         intermediate_boxes = [reference_boxes]
         intermediate_presence_logits = []
@@ -1734,43 +1729,45 @@
             vision_cross_attn_mask = None
             if spatial_shapes is not None and spatial_shapes.shape[0] == 1:
                 spatial_shape = (spatial_shapes[0, 0], spatial_shapes[0, 1])
-                vision_cross_attn_mask = self._get_rpb_matrix(reference_boxes, spatial_shape)
+                rpb_matrix = self._get_rpb_matrix(reference_boxes, spatial_shape)
+                # Prepend zeros row for presence token (it attends to all vision tokens equally)
+                vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0)
 
-            hidden_states, presence_token = layer(
+            hidden_states = layer(
                 hidden_states,
                 query_pos=query_pos,
                 text_features=text_features,
                 vision_features=vision_features,
                 vision_pos_encoding=vision_pos_encoding,
-                text_mask=text_mask,
+                text_cross_attn_mask=text_cross_attn_mask,
                 vision_cross_attn_mask=vision_cross_attn_mask,
-                presence_token=presence_token,
                 **kwargs,
             )
 
+            # Extract query hidden states (without presence token) for box refinement
+            query_hidden_states = hidden_states[:, 1:]
+
             # Box refinement: predict delta and update reference boxes
             reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes)
-            delta_boxes = self.box_head(self.output_layer_norm(hidden_states))
+            delta_boxes = self.box_head(self.output_layer_norm(query_hidden_states))
             new_reference_boxes = (delta_boxes + reference_boxes_before_sigmoid).sigmoid()
             reference_boxes = new_reference_boxes.detach()
 
-            intermediate_outputs.append(self.output_layer_norm(hidden_states))
+            intermediate_outputs.append(self.output_layer_norm(query_hidden_states))
             intermediate_boxes.append(new_reference_boxes)
 
             # Process presence token
-
-
-
-
-
-
+            presence_hidden = hidden_states[:, :1]
+            presence_logits = self.presence_head(self.presence_layer_norm(presence_hidden)).squeeze(-1)
+            presence_logits = presence_logits.clamp(
+                min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
+            )
+            intermediate_presence_logits.append(presence_logits)
 
         # Stack outputs from all layers
         intermediate_outputs = torch.stack(intermediate_outputs)
         intermediate_boxes = torch.stack(intermediate_boxes[:-1])
-        intermediate_presence_logits = (
-            torch.stack(intermediate_presence_logits) if intermediate_presence_logits else None
-        )
+        intermediate_presence_logits = torch.stack(intermediate_presence_logits)
 
         return Sam3DETRDecoderOutput(
             intermediate_hidden_states=intermediate_outputs,
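After this refactor the presence token lives at position 0 of `hidden_states` for the whole decoder loop: `[:, :1]` feeds the presence head and `[:, 1:]` feeds box refinement, with the logits clamped to a symmetric range. A sketch with stand-in modules and an assumed clamp bound:

```python
import torch
from torch import nn

batch, num_queries, hidden = 2, 10, 8
clamp_max = 10.0  # assumed bound; the real model reads it from its config

hidden_states = torch.randn(batch, 1 + num_queries, hidden)  # presence token first

presence_hidden = hidden_states[:, :1]  # [batch, 1, hidden]
query_hidden = hidden_states[:, 1:]     # [batch, num_queries, hidden]

presence_head = nn.Linear(hidden, 1)    # stand-in for the model's head
presence_logits = presence_head(presence_hidden).squeeze(-1)  # [batch, 1]
presence_logits = presence_logits.clamp(min=-clamp_max, max=clamp_max)
print(presence_logits.shape, query_hidden.shape)
```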
--- a/transformers/models/sam3_tracker/modeling_sam3_tracker.py
+++ b/transformers/models/sam3_tracker/modeling_sam3_tracker.py
@@ -107,7 +107,12 @@ class Sam3TrackerFeedForward(nn.Module):
         return hidden_states
 
 
-@auto_docstring
+@auto_docstring(
+    custom_intro="""
+    Segment Anything Model 3 (SAM 3) for generating segmentation masks, given an input image and
+    input points and labels, boxes, or masks.
+    """
+)
 class Sam3TrackerPreTrainedModel(PreTrainedModel):
     config_class = Sam3TrackerConfig
     base_model_prefix = "sam3_tracker"
--- a/transformers/models/sam3_tracker/modular_sam3_tracker.py
+++ b/transformers/models/sam3_tracker/modular_sam3_tracker.py
@@ -136,7 +136,12 @@ class Sam3TrackerFeedForward(Sam2FeedForward):
     pass
 
 
-@auto_docstring
+@auto_docstring(
+    custom_intro="""
+    Segment Anything Model 3 (SAM 3) for generating segmentation masks, given an input image and
+    input points and labels, boxes, or masks.
+    """
+)
 class Sam3TrackerPreTrainedModel(Sam2PreTrainedModel):
     @torch.no_grad()
     def _init_weights(self, module):
--- a/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py
+++ b/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py
@@ -1719,6 +1719,7 @@ class Sam3TrackerVideoModel(Sam3TrackerVideoPreTrainedModel):
         frame: Optional[torch.Tensor] = None,
         reverse: bool = False,
         run_mem_encoder: bool = True,
+        **kwargs,
     ) -> Sam3TrackerVideoSegmentationOutput:
         r"""
         inference_session (`Sam3TrackerVideoInferenceSession`):
--- a/transformers/models/seamless_m4t/modeling_seamless_m4t.py
+++ b/transformers/models/seamless_m4t/modeling_seamless_m4t.py
@@ -1770,6 +1770,7 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1914,6 +1915,7 @@ class SeamlessM4TTextToUnitModel(SeamlessM4TPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -2035,6 +2037,7 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -2354,7 +2357,7 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel):
         return input_lengths
 
     def forward(
-        self, input_ids: torch.LongTensor, spkr_id: torch.Tensor, lang_id: torch.Tensor
+        self, input_ids: torch.LongTensor, spkr_id: torch.Tensor, lang_id: torch.Tensor, **kwargs
     ) -> tuple[torch.Tensor]:
         """
         Args:
@@ -2996,6 +2999,7 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
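The dominant change in this file is adding a trailing `**kwargs` to `forward` signatures so shared calling code can pass newer keyword arguments (for example `cache_position`) without older modules raising `TypeError`. A generic illustration of why the catch-all helps:

```python
import torch
from torch import nn

class StrictBlock(nn.Module):
    def forward(self, hidden_states):
        return hidden_states

class TolerantBlock(nn.Module):
    def forward(self, hidden_states, **kwargs):  # extra kwargs are ignored
        return hidden_states

x = torch.randn(1, 4)
TolerantBlock()(x, cache_position=None)  # works
try:
    StrictBlock()(x, cache_position=None)
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'cache_position'
```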
--- a/transformers/models/seamless_m4t/tokenization_seamless_m4t.py
+++ b/transformers/models/seamless_m4t/tokenization_seamless_m4t.py
@@ -60,7 +60,7 @@ class SeamlessM4TTokenizer(TokenizersBackend):
     Args:
         vocab (`list` or `dict`, *optional*):
             List of (token, score) tuples or dict mapping tokens to indices. If not provided, uses default vocab.
-        merges (`list`, *optional*):
+        merges (`str` or `list`, *optional*):
             List of merge rules for BPE model. If not provided, uses empty list.
         bos_token (`str`, *optional*, defaults to `"<s>"`):
             The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
@@ -104,15 +104,15 @@ class SeamlessM4TTokenizer(TokenizersBackend):
 
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = BPE
 
-    prefix_tokens: list[int] =
-    suffix_tokens: list[int] =
+    prefix_tokens: list[int] = None
+    suffix_tokens: list[int] = None
 
     def __init__(
         self,
-        vocab: Optional[
-        merges: Optional[list] = None,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
         bos_token="<s>",
         eos_token="</s>",
         sep_token="</s>",
@@ -126,59 +126,14 @@ class SeamlessM4TTokenizer(TokenizersBackend):
         vocab_file=None,
         **kwargs,
     ):
-
-
-
-
-
-
-        }
-
-        # Process vocab - SeamlessM4T uses fairseq vocab alignment: <pad>=0, <unk>=1, <s>=2, </s>=3, then SPM pieces[3:]
-        if isinstance(vocab, list):
-            # Convert list of (token, score) tuples to dict {token: idx}
-            # Check if vocab is already in SeamlessM4T order (pad, unk, s, /s) or tokenizer.json order (unk, s, /s, ...)
-            first_tokens = [str(item[0]) if isinstance(item, (list, tuple)) else str(item) for item in vocab[:4]]
-            is_seamless_order = (
-                len(first_tokens) >= 4
-                and first_tokens[0] == str(pad_token)
-                and first_tokens[1] == str(unk_token)
-                and first_tokens[2] == str(bos_token)
-                and first_tokens[3] == str(eos_token)
-            )
-
-            if is_seamless_order:
-                # Already in correct order, use list index directly as token ID
-                vocab_dict = {}
-                for idx, item in enumerate(vocab):
-                    token = str(item[0]) if isinstance(item, (list, tuple)) else str(item)
-                    vocab_dict[token] = idx
-                self._vocab = vocab_dict
-            else:
-                # Reorder to fairseq: <pad>, <unk>, <s>, </s>, ... (rest of vocab)
-                vocab_dict = {}
-                vocab_dict[str(pad_token)] = 0
-                vocab_dict[str(unk_token)] = 1
-                vocab_dict[str(bos_token)] = 2
-                vocab_dict[str(eos_token)] = 3
-
-                # Add rest of vocab starting from index 4, skipping tokens we already added
-                idx = 4
-                for item in vocab:
-                    token = str(item[0]) if isinstance(item, (list, tuple)) else str(item)
-                    if token not in vocab_dict:
-                        vocab_dict[token] = idx
-                        idx += 1
-
-                self._vocab = vocab_dict
-        else:
-            self._vocab = vocab
-
-        if merges is None:
-            self._merges = []
-        else:
-            self._merges = [tuple(merge) if isinstance(merge, list) else merge for merge in merges]
+        self._vocab = vocab or {
+            str(pad_token): 0,
+            str(unk_token): 1,
+            str(bos_token): 2,
+            str(eos_token): 3,
+        }
 
+        self._merges = merges or []
         self._tokenizer = Tokenizer(
             BPE(
                 vocab=self._vocab,
@@ -216,7 +171,6 @@ class SeamlessM4TTokenizer(TokenizersBackend):
         kwargs.setdefault("additional_special_tokens", additional_special_tokens)
 
         super().__init__(
-            tokenizer_object=self._tokenizer,
             bos_token=bos_token,
             eos_token=eos_token,
             sep_token=sep_token,
@@ -245,6 +199,20 @@ class SeamlessM4TTokenizer(TokenizersBackend):
 
         self.set_tgt_lang_special_tokens(self._tgt_lang)
 
+    @classmethod
+    def convert_from_spm_model(cls, vocab, **kwargs):
+        """When converting from spm, offset is needed to account for special tokens."""
+        _vocab = {
+            "<pad>": 0,
+            "<unk>": 1,
+            "<s>": 2,
+            "</s>": 3,
+        }
+        for i, token in enumerate(list(vocab.keys())):
+            _vocab[token] = i + 1  # offset by 1 to account for special tokens
+        kwargs["vocab"] = _vocab
+        return kwargs
+
     @property
     def src_lang(self) -> str:
         return self._src_lang
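The new `convert_from_spm_model` hook encodes the fairseq-style vocab alignment (`<pad>`=0, `<unk>`=1, `<s>`=2, `</s>`=3) that SeamlessM4T expects when converting from a SentencePiece model. A standalone sketch of the same mapping with a made-up SentencePiece vocab:

```python
# Made-up SPM vocab: SentencePiece conventionally starts with <unk>, <s>, </s>.
spm_vocab = {"<unk>": 0, "<s>": 1, "</s>": 2, "▁hello": 3, "▁world": 4}

vocab = {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3}
for i, token in enumerate(spm_vocab):
    vocab[token] = i + 1  # shift SPM ids up by one to make room for <pad>

print(vocab)
# {'<pad>': 0, '<unk>': 1, '<s>': 2, '</s>': 3, '▁hello': 4, '▁world': 5}
```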
--- a/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
+++ b/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
@@ -1812,6 +1812,7 @@ class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1995,6 +1996,7 @@ class SeamlessM4Tv2TextToUnitDecoder(SeamlessM4Tv2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SeamlessM4Tv2TextToUnitDecoderOutput]:
         r"""
         Args:
@@ -2122,6 +2124,7 @@ class SeamlessM4Tv2TextToUnitModel(SeamlessM4Tv2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -2556,7 +2559,7 @@ class SeamlessM4Tv2CodeHifiGan(PreTrainedModel):
 
     # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.forward with SeamlessM4T->SeamlessM4Tv2, spkr_id->speaker_id
     def forward(
-        self, input_ids: torch.LongTensor, speaker_id: torch.Tensor, lang_id: torch.Tensor
+        self, input_ids: torch.LongTensor, speaker_id: torch.Tensor, lang_id: torch.Tensor, **kwargs
     ) -> tuple[torch.Tensor]:
         """
         Args:
@@ -3214,6 +3217,7 @@ class SeamlessM4Tv2ForTextToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
--- a/transformers/models/seed_oss/modeling_seed_oss.py
+++ b/transformers/models/seed_oss/modeling_seed_oss.py
@@ -40,7 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_seed_oss import SeedOssConfig
 
 
@@ -350,7 +350,7 @@ class SeedOssRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
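RoPE cos/sin tables are deliberately computed in float32; this hunk swaps the plain autocast context for transformers' internal `maybe_autocast` wrapper, but the intent is unchanged. A sketch of the underlying pattern with plain `torch.autocast`:

```python
import torch

dim, seq_len = 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
position_ids = torch.arange(seq_len).float()[None, :]

with torch.autocast(device_type="cpu", enabled=False):  # force float32
    freqs = (inv_freq[None, :, None].float() @ position_ids[:, None, :].float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos(), emb.sin()

print(cos.dtype, cos.shape)  # torch.float32 torch.Size([1, 16, 8])
```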
--- a/transformers/models/segformer/modeling_segformer.py
+++ b/transformers/models/segformer/modeling_segformer.py
@@ -434,6 +434,7 @@ class SegformerModel(SegformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -486,6 +487,7 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SegFormerImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -572,7 +574,7 @@ class SegformerDecodeHead(SegformerPreTrainedModel):
 
         self.config = config
 
-    def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
+    def forward(self, encoder_hidden_states: torch.FloatTensor, **kwargs) -> torch.Tensor:
         batch_size = encoder_hidden_states[-1].shape[0]
 
         all_hidden_states = ()
@@ -627,6 +629,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SemanticSegmenterOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
--- a/transformers/models/seggpt/modeling_seggpt.py
+++ b/transformers/models/seggpt/modeling_seggpt.py
@@ -647,6 +647,7 @@ class SegGptModel(SegGptPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SegGptEncoderOutput]:
         r"""
         prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
@@ -843,6 +844,7 @@ class SegGptForImageSegmentation(SegGptPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SegGptImageSegmentationOutput]:
         r"""
         prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|