transformers 5.0.0rc3-py3-none-any.whl → 5.1.0-py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only and reflects the published packages exactly as distributed.
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
|
@@ -19,20 +19,24 @@
|
|
|
19
19
|
# limitations under the License.
|
|
20
20
|
import math
|
|
21
21
|
import warnings
|
|
22
|
+
from collections.abc import Callable
|
|
22
23
|
from dataclasses import dataclass
|
|
23
24
|
|
|
24
25
|
import torch
|
|
26
|
+
import torch.nn as nn
|
|
25
27
|
import torch.nn.functional as F
|
|
26
|
-
from torch import Tensor
|
|
28
|
+
from torch import Tensor
|
|
27
29
|
|
|
28
30
|
from ... import initialization as init
|
|
29
31
|
from ...activations import ACT2CLS, ACT2FN
|
|
32
|
+
from ...backbone_utils import load_backbone
|
|
30
33
|
from ...image_transforms import center_to_corners_format, corners_to_center_format
|
|
31
34
|
from ...modeling_outputs import BaseModelOutput
|
|
32
|
-
from ...modeling_utils import PreTrainedModel
|
|
35
|
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
36
|
+
from ...processing_utils import Unpack
|
|
33
37
|
from ...pytorch_utils import compile_compatible_method_lru_cache
|
|
34
|
-
from ...utils import ModelOutput, auto_docstring,
|
|
35
|
-
from ...utils.
|
|
38
|
+
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int
|
|
39
|
+
from ...utils.generic import can_return_tuple, check_model_inputs
|
|
36
40
|
from .configuration_rt_detr_v2 import RTDetrV2Config
|
|
37
41
|
|
|
38
42
|
|
|
@@ -169,7 +173,7 @@ class RTDetrV2MultiscaleDeformableAttention(nn.Module):
|
|
|
169
173
|
spatial_shapes=None,
|
|
170
174
|
spatial_shapes_list=None,
|
|
171
175
|
level_start_index=None,
|
|
172
|
-
|
|
176
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
173
177
|
):
|
|
174
178
|
# Process inputs up to sampling locations calculation using parent class logic
|
|
175
179
|
if position_embeddings is not None:
|
|
@@ -177,10 +181,10 @@ class RTDetrV2MultiscaleDeformableAttention(nn.Module):
|
|
|
177
181
|
|
|
178
182
|
batch_size, num_queries, _ = hidden_states.shape
|
|
179
183
|
batch_size, sequence_length, _ = encoder_hidden_states.shape
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
+
torch_compilable_check(
|
|
185
|
+
(spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == sequence_length,
|
|
186
|
+
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
|
|
187
|
+
)
|
|
184
188
|
|
|
185
189
|
value = self.value_proj(encoder_hidden_states)
|
|
186
190
|
if attention_mask is not None:
|
|
@@ -220,167 +224,159 @@ class RTDetrV2MultiscaleDeformableAttention(nn.Module):
|
|
|
220
224
|
return output, attention_weights
|
|
221
225
|
|
|
222
226
|
|
|
223
|
-
class
|
|
227
|
+
class RTDetrV2MLP(nn.Module):
|
|
228
|
+
def __init__(self, config: RTDetrV2Config, hidden_size: int, intermediate_size: int, activation_function: str):
|
|
229
|
+
super().__init__()
|
|
230
|
+
self.fc1 = nn.Linear(hidden_size, intermediate_size)
|
|
231
|
+
self.fc2 = nn.Linear(intermediate_size, hidden_size)
|
|
232
|
+
self.activation_fn = ACT2FN[activation_function]
|
|
233
|
+
self.activation_dropout = config.activation_dropout
|
|
234
|
+
self.dropout = config.dropout
|
|
235
|
+
|
|
236
|
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
237
|
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
|
238
|
+
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
|
239
|
+
hidden_states = self.fc2(hidden_states)
|
|
240
|
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
241
|
+
return hidden_states
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def eager_attention_forward(
|
|
245
|
+
module: nn.Module,
|
|
246
|
+
query: torch.Tensor,
|
|
247
|
+
key: torch.Tensor,
|
|
248
|
+
value: torch.Tensor,
|
|
249
|
+
attention_mask: torch.Tensor | None,
|
|
250
|
+
scaling: float | None = None,
|
|
251
|
+
dropout: float = 0.0,
|
|
252
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
253
|
+
):
|
|
254
|
+
if scaling is None:
|
|
255
|
+
scaling = query.size(-1) ** -0.5
|
|
256
|
+
|
|
257
|
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
|
258
|
+
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
|
259
|
+
|
|
260
|
+
if attention_mask is not None:
|
|
261
|
+
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
|
262
|
+
attn_weights = attn_weights + attention_mask
|
|
263
|
+
|
|
264
|
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
|
265
|
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
|
266
|
+
|
|
267
|
+
attn_output = torch.matmul(attn_weights, value)
|
|
268
|
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
269
|
+
|
|
270
|
+
return attn_output, attn_weights
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class RTDetrV2SelfAttention(nn.Module):
|
|
224
274
|
"""
|
|
225
|
-
Multi-headed attention from 'Attention Is All You Need' paper.
|
|
275
|
+
Multi-headed self-attention from 'Attention Is All You Need' paper.
|
|
226
276
|
|
|
227
|
-
|
|
277
|
+
In RT_DETR_V2, position embeddings are added to both queries and keys (but not values) in self-attention.
|
|
228
278
|
"""
|
|
229
279
|
|
|
230
280
|
def __init__(
|
|
231
281
|
self,
|
|
232
|
-
|
|
233
|
-
|
|
282
|
+
config: RTDetrV2Config,
|
|
283
|
+
hidden_size: int,
|
|
284
|
+
num_attention_heads: int,
|
|
234
285
|
dropout: float = 0.0,
|
|
235
286
|
bias: bool = True,
|
|
236
287
|
):
|
|
237
288
|
super().__init__()
|
|
238
|
-
self.
|
|
239
|
-
self.
|
|
240
|
-
self.dropout = dropout
|
|
241
|
-
self.head_dim = embed_dim // num_heads
|
|
242
|
-
if self.head_dim * num_heads != self.embed_dim:
|
|
243
|
-
raise ValueError(
|
|
244
|
-
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
|
245
|
-
f" {num_heads})."
|
|
246
|
-
)
|
|
289
|
+
self.config = config
|
|
290
|
+
self.head_dim = hidden_size // num_attention_heads
|
|
247
291
|
self.scaling = self.head_dim**-0.5
|
|
292
|
+
self.attention_dropout = dropout
|
|
293
|
+
self.is_causal = False
|
|
248
294
|
|
|
249
|
-
self.k_proj = nn.Linear(
|
|
250
|
-
self.v_proj = nn.Linear(
|
|
251
|
-
self.q_proj = nn.Linear(
|
|
252
|
-
self.
|
|
253
|
-
|
|
254
|
-
def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
|
|
255
|
-
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
|
256
|
-
|
|
257
|
-
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Tensor | None):
|
|
258
|
-
return tensor if position_embeddings is None else tensor + position_embeddings
|
|
295
|
+
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
|
|
296
|
+
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
|
|
297
|
+
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
|
|
298
|
+
self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
|
|
259
299
|
|
|
260
300
|
def forward(
|
|
261
301
|
self,
|
|
262
302
|
hidden_states: torch.Tensor,
|
|
263
303
|
attention_mask: torch.Tensor | None = None,
|
|
264
304
|
position_embeddings: torch.Tensor | None = None,
|
|
265
|
-
|
|
266
|
-
) -> tuple[torch.Tensor, torch.Tensor
|
|
267
|
-
"""
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
hidden_states_original = hidden_states
|
|
273
|
-
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
|
|
274
|
-
|
|
275
|
-
# get queries, keys and values
|
|
276
|
-
query_states = self.q_proj(hidden_states) * self.scaling
|
|
277
|
-
key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size)
|
|
278
|
-
value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size)
|
|
279
|
-
|
|
280
|
-
proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
|
|
281
|
-
query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape)
|
|
282
|
-
key_states = key_states.view(*proj_shape)
|
|
283
|
-
value_states = value_states.view(*proj_shape)
|
|
284
|
-
|
|
285
|
-
source_len = key_states.size(1)
|
|
286
|
-
|
|
287
|
-
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
|
288
|
-
|
|
289
|
-
if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
|
|
290
|
-
raise ValueError(
|
|
291
|
-
f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
|
|
292
|
-
f" {attn_weights.size()}"
|
|
293
|
-
)
|
|
294
|
-
|
|
295
|
-
# expand attention_mask
|
|
296
|
-
if attention_mask is not None:
|
|
297
|
-
# [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
|
|
298
|
-
attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size())
|
|
299
|
-
|
|
300
|
-
if attention_mask is not None:
|
|
301
|
-
if attention_mask.size() != (batch_size, 1, target_len, source_len):
|
|
302
|
-
raise ValueError(
|
|
303
|
-
f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
|
|
304
|
-
f" {attention_mask.size()}"
|
|
305
|
-
)
|
|
306
|
-
if attention_mask.dtype == torch.bool:
|
|
307
|
-
attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
|
|
308
|
-
attention_mask, -torch.inf
|
|
309
|
-
)
|
|
310
|
-
attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
|
|
311
|
-
attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
|
|
312
|
-
|
|
313
|
-
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
|
314
|
-
|
|
315
|
-
if output_attentions:
|
|
316
|
-
# this operation is a bit awkward, but it's required to
|
|
317
|
-
# make sure that attn_weights keeps its gradient.
|
|
318
|
-
# In order to do so, attn_weights have to reshaped
|
|
319
|
-
# twice and have to be reused in the following
|
|
320
|
-
attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
|
|
321
|
-
attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
|
|
322
|
-
else:
|
|
323
|
-
attn_weights_reshaped = None
|
|
324
|
-
|
|
325
|
-
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
|
305
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
306
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
307
|
+
"""
|
|
308
|
+
Position embeddings are added to both queries and keys (but not values).
|
|
309
|
+
"""
|
|
310
|
+
input_shape = hidden_states.shape[:-1]
|
|
311
|
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
|
326
312
|
|
|
327
|
-
|
|
313
|
+
query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
|
|
328
314
|
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
f" {attn_output.size()}"
|
|
333
|
-
)
|
|
315
|
+
query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
|
|
316
|
+
key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
|
|
317
|
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
334
318
|
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
319
|
+
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
|
320
|
+
self.config._attn_implementation, eager_attention_forward
|
|
321
|
+
)
|
|
338
322
|
|
|
339
|
-
attn_output =
|
|
323
|
+
attn_output, attn_weights = attention_interface(
|
|
324
|
+
self,
|
|
325
|
+
query_states,
|
|
326
|
+
key_states,
|
|
327
|
+
value_states,
|
|
328
|
+
attention_mask,
|
|
329
|
+
dropout=0.0 if not self.training else self.attention_dropout,
|
|
330
|
+
scaling=self.scaling,
|
|
331
|
+
**kwargs,
|
|
332
|
+
)
|
|
340
333
|
|
|
341
|
-
|
|
334
|
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
|
335
|
+
attn_output = self.o_proj(attn_output)
|
|
336
|
+
return attn_output, attn_weights
|
|
342
337
|
|
|
343
338
|
|
|
344
339
|
class RTDetrV2DecoderLayer(nn.Module):
|
|
345
340
|
def __init__(self, config: RTDetrV2Config):
|
|
346
341
|
super().__init__()
|
|
342
|
+
self.hidden_size = config.d_model
|
|
343
|
+
|
|
347
344
|
# self-attention
|
|
348
|
-
self.self_attn =
|
|
349
|
-
|
|
350
|
-
|
|
345
|
+
self.self_attn = RTDetrV2SelfAttention(
|
|
346
|
+
config=config,
|
|
347
|
+
hidden_size=self.hidden_size,
|
|
348
|
+
num_attention_heads=config.decoder_attention_heads,
|
|
351
349
|
dropout=config.attention_dropout,
|
|
352
350
|
)
|
|
353
351
|
self.dropout = config.dropout
|
|
354
|
-
self.activation_fn = ACT2FN[config.decoder_activation_function]
|
|
355
|
-
self.activation_dropout = config.activation_dropout
|
|
356
352
|
|
|
357
|
-
self.self_attn_layer_norm = nn.LayerNorm(
|
|
353
|
+
self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
|
358
354
|
# override only the encoder attention module with v2 version
|
|
359
355
|
self.encoder_attn = RTDetrV2MultiscaleDeformableAttention(config)
|
|
360
|
-
self.encoder_attn_layer_norm = nn.LayerNorm(
|
|
356
|
+
self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
|
361
357
|
# feedforward neural networks
|
|
362
|
-
self.
|
|
363
|
-
self.
|
|
364
|
-
self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
|
|
358
|
+
self.mlp = RTDetrV2MLP(config, self.hidden_size, config.decoder_ffn_dim, config.decoder_activation_function)
|
|
359
|
+
self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
|
365
360
|
|
|
366
361
|
def forward(
|
|
367
362
|
self,
|
|
368
363
|
hidden_states: torch.Tensor,
|
|
369
|
-
|
|
364
|
+
object_queries_position_embeddings: torch.Tensor | None = None,
|
|
370
365
|
reference_points=None,
|
|
371
366
|
spatial_shapes=None,
|
|
372
367
|
spatial_shapes_list=None,
|
|
373
368
|
level_start_index=None,
|
|
374
369
|
encoder_hidden_states: torch.Tensor | None = None,
|
|
375
370
|
encoder_attention_mask: torch.Tensor | None = None,
|
|
376
|
-
|
|
377
|
-
):
|
|
371
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
372
|
+
) -> torch.Tensor:
|
|
378
373
|
"""
|
|
379
374
|
Args:
|
|
380
375
|
hidden_states (`torch.FloatTensor`):
|
|
381
|
-
Input to the layer of shape `(
|
|
382
|
-
|
|
383
|
-
Position embeddings
|
|
376
|
+
Input to the layer of shape `(batch, seq_len, hidden_size)`.
|
|
377
|
+
object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
|
|
378
|
+
Position embeddings for the object query slots. These are added to both queries and keys
|
|
379
|
+
in the self-attention layer (not values).
|
|
384
380
|
reference_points (`torch.FloatTensor`, *optional*):
|
|
385
381
|
Reference points.
|
|
386
382
|
spatial_shapes (`torch.LongTensor`, *optional*):
|
|
@@ -388,63 +384,51 @@ class RTDetrV2DecoderLayer(nn.Module):
|
|
|
388
384
|
level_start_index (`torch.LongTensor`, *optional*):
|
|
389
385
|
Level start index.
|
|
390
386
|
encoder_hidden_states (`torch.FloatTensor`):
|
|
391
|
-
cross attention input to the layer of shape `(
|
|
387
|
+
cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
|
|
392
388
|
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
|
|
393
389
|
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
|
394
390
|
values.
|
|
395
|
-
output_attentions (`bool`, *optional*):
|
|
396
|
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
397
|
-
returned tensors for more detail.
|
|
398
391
|
"""
|
|
399
392
|
residual = hidden_states
|
|
400
393
|
|
|
401
394
|
# Self Attention
|
|
402
|
-
hidden_states,
|
|
395
|
+
hidden_states, _ = self.self_attn(
|
|
403
396
|
hidden_states=hidden_states,
|
|
404
397
|
attention_mask=encoder_attention_mask,
|
|
405
|
-
position_embeddings=
|
|
406
|
-
|
|
398
|
+
position_embeddings=object_queries_position_embeddings,
|
|
399
|
+
**kwargs,
|
|
407
400
|
)
|
|
408
401
|
|
|
409
402
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
410
403
|
hidden_states = residual + hidden_states
|
|
411
404
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
|
412
405
|
|
|
413
|
-
|
|
406
|
+
residual = hidden_states
|
|
414
407
|
|
|
415
408
|
# Cross-Attention
|
|
416
|
-
|
|
417
|
-
hidden_states, cross_attn_weights = self.encoder_attn(
|
|
409
|
+
hidden_states, _ = self.encoder_attn(
|
|
418
410
|
hidden_states=hidden_states,
|
|
419
411
|
encoder_hidden_states=encoder_hidden_states,
|
|
420
|
-
position_embeddings=
|
|
412
|
+
position_embeddings=object_queries_position_embeddings,
|
|
421
413
|
reference_points=reference_points,
|
|
422
414
|
spatial_shapes=spatial_shapes,
|
|
423
415
|
spatial_shapes_list=spatial_shapes_list,
|
|
424
416
|
level_start_index=level_start_index,
|
|
425
|
-
|
|
417
|
+
**kwargs,
|
|
426
418
|
)
|
|
427
419
|
|
|
428
420
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
429
|
-
hidden_states =
|
|
421
|
+
hidden_states = residual + hidden_states
|
|
430
422
|
|
|
431
423
|
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
|
432
424
|
|
|
433
425
|
# Fully Connected
|
|
434
426
|
residual = hidden_states
|
|
435
|
-
hidden_states = self.
|
|
436
|
-
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
|
437
|
-
hidden_states = self.fc2(hidden_states)
|
|
438
|
-
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
427
|
+
hidden_states = self.mlp(hidden_states)
|
|
439
428
|
hidden_states = residual + hidden_states
|
|
440
429
|
hidden_states = self.final_layer_norm(hidden_states)
|
|
441
430
|
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
if output_attentions:
|
|
445
|
-
outputs += (self_attn_weights, cross_attn_weights)
|
|
446
|
-
|
|
447
|
-
return outputs
|
|
431
|
+
return hidden_states
|
|
448
432
|
|
|
449
433
|
|
|
450
434
|
@auto_docstring
|
|
@@ -454,6 +438,10 @@ class RTDetrV2PreTrainedModel(PreTrainedModel):
|
|
|
454
438
|
main_input_name = "pixel_values"
|
|
455
439
|
input_modalities = ("image",)
|
|
456
440
|
_no_split_modules = [r"RTDetrV2HybridEncoder", r"RTDetrV2DecoderLayer"]
|
|
441
|
+
_supports_sdpa = True
|
|
442
|
+
_supports_flash_attn = True
|
|
443
|
+
_supports_attention_backend = True
|
|
444
|
+
_supports_flex_attn = True
|
|
457
445
|
|
|
458
446
|
@torch.no_grad()
|
|
459
447
|
def _init_weights(self, module):
|
|
@@ -568,12 +556,18 @@ def inverse_sigmoid(x, eps=1e-5):
|
|
|
568
556
|
|
|
569
557
|
|
|
570
558
|
class RTDetrV2Decoder(RTDetrV2PreTrainedModel):
|
|
559
|
+
_can_record_outputs = {
|
|
560
|
+
"hidden_states": RTDetrV2DecoderLayer,
|
|
561
|
+
"attentions": RTDetrV2SelfAttention,
|
|
562
|
+
"cross_attentions": RTDetrV2MultiscaleDeformableAttention,
|
|
563
|
+
}
|
|
564
|
+
|
|
571
565
|
def __init__(self, config: RTDetrV2Config):
|
|
572
566
|
super().__init__(config)
|
|
573
567
|
|
|
574
568
|
self.dropout = config.dropout
|
|
575
569
|
self.layers = nn.ModuleList([RTDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)])
|
|
576
|
-
self.query_pos_head = RTDetrV2MLPPredictionHead(
|
|
570
|
+
self.query_pos_head = RTDetrV2MLPPredictionHead(4, 2 * config.d_model, config.d_model, num_layers=2)
|
|
577
571
|
|
|
578
572
|
# hack implementation for iterative bounding box refinement and two-stage Deformable DETR
|
|
579
573
|
self.bbox_embed = None
|
|
@@ -582,21 +576,17 @@ class RTDetrV2Decoder(RTDetrV2PreTrainedModel):
|
|
|
582
576
|
# Initialize weights and apply final processing
|
|
583
577
|
self.post_init()
|
|
584
578
|
|
|
579
|
+
@check_model_inputs()
|
|
585
580
|
def forward(
|
|
586
581
|
self,
|
|
587
582
|
inputs_embeds=None,
|
|
588
583
|
encoder_hidden_states=None,
|
|
589
584
|
encoder_attention_mask=None,
|
|
590
|
-
position_embeddings=None,
|
|
591
585
|
reference_points=None,
|
|
592
586
|
spatial_shapes=None,
|
|
593
587
|
spatial_shapes_list=None,
|
|
594
588
|
level_start_index=None,
|
|
595
|
-
|
|
596
|
-
output_attentions=None,
|
|
597
|
-
output_hidden_states=None,
|
|
598
|
-
return_dict=None,
|
|
599
|
-
**kwargs,
|
|
589
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
600
590
|
):
|
|
601
591
|
r"""
|
|
602
592
|
Args:
|
|
@@ -610,39 +600,17 @@ class RTDetrV2Decoder(RTDetrV2PreTrainedModel):
|
|
|
610
600
|
in `[0, 1]`:
|
|
611
601
|
- 1 for pixels that are real (i.e. **not masked**),
|
|
612
602
|
- 0 for pixels that are padding (i.e. **masked**).
|
|
613
|
-
position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
|
|
614
|
-
Position embeddings that are added to the queries and keys in each self-attention layer.
|
|
615
603
|
reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*):
|
|
616
604
|
Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
|
|
617
605
|
spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
|
|
618
606
|
Spatial shapes of the feature maps.
|
|
619
607
|
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
|
|
620
608
|
Indexes for the start of each feature level. In range `[0, sequence_length]`.
|
|
621
|
-
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
|
|
622
|
-
Ratio of valid area in each feature level.
|
|
623
|
-
|
|
624
|
-
output_attentions (`bool`, *optional*):
|
|
625
|
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
626
|
-
returned tensors for more detail.
|
|
627
|
-
output_hidden_states (`bool`, *optional*):
|
|
628
|
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
|
629
|
-
for more detail.
|
|
630
|
-
return_dict (`bool`, *optional*):
|
|
631
|
-
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
|
632
609
|
"""
|
|
633
|
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
634
|
-
output_hidden_states = (
|
|
635
|
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
636
|
-
)
|
|
637
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
638
|
-
|
|
639
610
|
if inputs_embeds is not None:
|
|
640
611
|
hidden_states = inputs_embeds
|
|
641
612
|
|
|
642
613
|
# decoder layers
|
|
643
|
-
all_hidden_states = () if output_hidden_states else None
|
|
644
|
-
all_self_attns = () if output_attentions else None
|
|
645
|
-
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
|
646
614
|
intermediate = ()
|
|
647
615
|
intermediate_reference_points = ()
|
|
648
616
|
intermediate_logits = ()
|
|
@@ -652,25 +620,20 @@ class RTDetrV2Decoder(RTDetrV2PreTrainedModel):
|
|
|
652
620
|
# https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L252
|
|
653
621
|
for idx, decoder_layer in enumerate(self.layers):
|
|
654
622
|
reference_points_input = reference_points.unsqueeze(2)
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
if output_hidden_states:
|
|
658
|
-
all_hidden_states += (hidden_states,)
|
|
623
|
+
object_queries_position_embeddings = self.query_pos_head(reference_points)
|
|
659
624
|
|
|
660
|
-
|
|
625
|
+
hidden_states = decoder_layer(
|
|
661
626
|
hidden_states,
|
|
662
|
-
|
|
627
|
+
object_queries_position_embeddings=object_queries_position_embeddings,
|
|
663
628
|
encoder_hidden_states=encoder_hidden_states,
|
|
664
629
|
reference_points=reference_points_input,
|
|
665
630
|
spatial_shapes=spatial_shapes,
|
|
666
631
|
spatial_shapes_list=spatial_shapes_list,
|
|
667
632
|
level_start_index=level_start_index,
|
|
668
633
|
encoder_attention_mask=encoder_attention_mask,
|
|
669
|
-
|
|
634
|
+
**kwargs,
|
|
670
635
|
)
|
|
671
636
|
|
|
672
|
-
hidden_states = layer_outputs[0]
|
|
673
|
-
|
|
674
637
|
# hack implementation for iterative bounding box refinement
|
|
675
638
|
if self.bbox_embed is not None:
|
|
676
639
|
predicted_corners = self.bbox_embed[idx](hidden_states)
|
|
@@ -686,44 +649,17 @@ class RTDetrV2Decoder(RTDetrV2PreTrainedModel):
|
|
|
686
649
|
logits = self.class_embed[idx](hidden_states)
|
|
687
650
|
intermediate_logits += (logits,)
|
|
688
651
|
|
|
689
|
-
if output_attentions:
|
|
690
|
-
all_self_attns += (layer_outputs[1],)
|
|
691
|
-
|
|
692
|
-
if encoder_hidden_states is not None:
|
|
693
|
-
all_cross_attentions += (layer_outputs[2],)
|
|
694
|
-
|
|
695
652
|
# Keep batch_size as first dimension
|
|
696
653
|
intermediate = torch.stack(intermediate, dim=1)
|
|
697
654
|
intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
|
|
698
655
|
if self.class_embed is not None:
|
|
699
656
|
intermediate_logits = torch.stack(intermediate_logits, dim=1)
|
|
700
657
|
|
|
701
|
-
# add hidden states from the last decoder layer
|
|
702
|
-
if output_hidden_states:
|
|
703
|
-
all_hidden_states += (hidden_states,)
|
|
704
|
-
|
|
705
|
-
if not return_dict:
|
|
706
|
-
return tuple(
|
|
707
|
-
v
|
|
708
|
-
for v in [
|
|
709
|
-
hidden_states,
|
|
710
|
-
intermediate,
|
|
711
|
-
intermediate_logits,
|
|
712
|
-
intermediate_reference_points,
|
|
713
|
-
all_hidden_states,
|
|
714
|
-
all_self_attns,
|
|
715
|
-
all_cross_attentions,
|
|
716
|
-
]
|
|
717
|
-
if v is not None
|
|
718
|
-
)
|
|
719
658
|
return RTDetrV2DecoderOutput(
|
|
720
659
|
last_hidden_state=hidden_states,
|
|
721
660
|
intermediate_hidden_states=intermediate,
|
|
722
661
|
intermediate_logits=intermediate_logits,
|
|
723
662
|
intermediate_reference_points=intermediate_reference_points,
|
|
724
|
-
hidden_states=all_hidden_states,
|
|
725
|
-
attentions=all_self_attns,
|
|
726
|
-
cross_attentions=all_cross_attentions,
|
|
727
663
|
)
|
|
728
664
|
|
|
729
665
|
|
|
@@ -905,50 +841,46 @@ class RTDetrV2EncoderLayer(nn.Module):
|
|
|
905
841
|
def __init__(self, config: RTDetrV2Config):
|
|
906
842
|
super().__init__()
|
|
907
843
|
self.normalize_before = config.normalize_before
|
|
844
|
+
self.hidden_size = config.encoder_hidden_dim
|
|
908
845
|
|
|
909
846
|
# self-attention
|
|
910
|
-
self.self_attn =
|
|
911
|
-
|
|
912
|
-
|
|
847
|
+
self.self_attn = RTDetrV2SelfAttention(
|
|
848
|
+
config=config,
|
|
849
|
+
hidden_size=self.hidden_size,
|
|
850
|
+
num_attention_heads=config.num_attention_heads,
|
|
913
851
|
dropout=config.dropout,
|
|
914
852
|
)
|
|
915
|
-
self.self_attn_layer_norm = nn.LayerNorm(
|
|
853
|
+
self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
|
916
854
|
self.dropout = config.dropout
|
|
917
|
-
self.
|
|
918
|
-
self.
|
|
919
|
-
self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim)
|
|
920
|
-
self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim)
|
|
921
|
-
self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
|
|
855
|
+
self.mlp = RTDetrV2MLP(config, self.hidden_size, config.encoder_ffn_dim, config.encoder_activation_function)
|
|
856
|
+
self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
|
922
857
|
|
|
923
858
|
def forward(
|
|
924
859
|
self,
|
|
925
860
|
hidden_states: torch.Tensor,
|
|
926
861
|
attention_mask: torch.Tensor,
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
):
|
|
862
|
+
spatial_position_embeddings: torch.Tensor | None = None,
|
|
863
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
864
|
+
) -> torch.Tensor:
|
|
931
865
|
"""
|
|
932
866
|
Args:
|
|
933
|
-
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len,
|
|
867
|
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
|
|
934
868
|
attention_mask (`torch.FloatTensor`): attention mask of size
|
|
935
869
|
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
|
936
870
|
values.
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
941
|
-
returned tensors for more detail.
|
|
871
|
+
spatial_position_embeddings (`torch.FloatTensor`, *optional*):
|
|
872
|
+
Spatial position embeddings (2D positional encodings of image locations), to be added to both
|
|
873
|
+
the queries and keys in self-attention (but not to values).
|
|
942
874
|
"""
|
|
943
875
|
residual = hidden_states
|
|
944
876
|
if self.normalize_before:
|
|
945
877
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
|
946
878
|
|
|
947
|
-
hidden_states,
|
|
879
|
+
hidden_states, _ = self.self_attn(
|
|
948
880
|
hidden_states=hidden_states,
|
|
949
881
|
attention_mask=attention_mask,
|
|
950
|
-
position_embeddings=
|
|
951
|
-
|
|
882
|
+
position_embeddings=spatial_position_embeddings,
|
|
883
|
+
**kwargs,
|
|
952
884
|
)
|
|
953
885
|
|
|
954
886
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
@@ -960,12 +892,7 @@ class RTDetrV2EncoderLayer(nn.Module):
|
|
|
960
892
|
hidden_states = self.final_layer_norm(hidden_states)
|
|
961
893
|
residual = hidden_states
|
|
962
894
|
|
|
963
|
-
hidden_states = self.
|
|
964
|
-
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
|
965
|
-
|
|
966
|
-
hidden_states = self.fc2(hidden_states)
|
|
967
|
-
|
|
968
|
-
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
895
|
+
hidden_states = self.mlp(hidden_states)
|
|
969
896
|
|
|
970
897
|
hidden_states = residual + hidden_states
|
|
971
898
|
if not self.normalize_before:
|
|
@@ -976,12 +903,7 @@ class RTDetrV2EncoderLayer(nn.Module):
|
|
|
976
903
|
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
|
977
904
|
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
|
978
905
|
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
if output_attentions:
|
|
982
|
-
outputs += (attn_weights,)
|
|
983
|
-
|
|
984
|
-
return outputs
|
|
906
|
+
return hidden_states
|
|
985
907
|
|
|
986
908
|
|
|
987
909
|
class RTDetrV2RepVggBlock(nn.Module):
|
|
@@ -1032,35 +954,119 @@ class RTDetrV2CSPRepLayer(nn.Module):
|
|
|
1032
954
|
return self.conv3(hidden_state_1 + hidden_state_2)
|
|
1033
955
|
|
|
1034
956
|
|
|
1035
|
-
class
|
|
957
|
+
class RTDetrV2SinePositionEmbedding(nn.Module):
|
|
958
|
+
"""
|
|
959
|
+
2D sinusoidal position embedding used in RT-DETR hybrid encoder.
|
|
960
|
+
"""
|
|
961
|
+
|
|
962
|
+
def __init__(self, embed_dim: int = 256, temperature: int = 10000):
|
|
963
|
+
super().__init__()
|
|
964
|
+
self.embed_dim = embed_dim
|
|
965
|
+
self.temperature = temperature
|
|
966
|
+
|
|
967
|
+
@compile_compatible_method_lru_cache(maxsize=32)
|
|
968
|
+
def forward(
|
|
969
|
+
self,
|
|
970
|
+
width: int,
|
|
971
|
+
height: int,
|
|
972
|
+
device: torch.device | str,
|
|
973
|
+
dtype: torch.dtype,
|
|
974
|
+
) -> torch.Tensor:
|
|
975
|
+
"""
|
|
976
|
+
Generate 2D sinusoidal position embeddings.
|
|
977
|
+
|
|
978
|
+
Returns:
|
|
979
|
+
Position embeddings of shape (1, height*width, embed_dim)
|
|
980
|
+
"""
|
|
981
|
+
grid_w = torch.arange(torch_int(width), device=device).to(dtype)
|
|
982
|
+
grid_h = torch.arange(torch_int(height), device=device).to(dtype)
|
|
983
|
+
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
|
984
|
+
if self.embed_dim % 4 != 0:
|
|
985
|
+
raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
|
|
986
|
+
pos_dim = self.embed_dim // 4
|
|
987
|
+
omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
|
|
988
|
+
omega = 1.0 / (self.temperature**omega)
|
|
989
|
+
|
|
990
|
+
out_w = grid_w.flatten()[..., None] @ omega[None]
|
|
991
|
+
out_h = grid_h.flatten()[..., None] @ omega[None]
|
|
992
|
+
|
|
993
|
+
return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
|
|
994
|
+
|
|
995
|
+
|
|
996
|
+
class RTDetrV2AIFILayer(nn.Module):
|
|
997
|
+
"""
|
|
998
|
+
AIFI (Attention-based Intra-scale Feature Interaction) layer used in RT-DETR hybrid encoder.
|
|
999
|
+
"""
|
|
1000
|
+
|
|
1036
1001
|
def __init__(self, config: RTDetrV2Config):
|
|
1037
1002
|
super().__init__()
|
|
1003
|
+
self.config = config
|
|
1004
|
+
self.encoder_hidden_dim = config.encoder_hidden_dim
|
|
1005
|
+
self.eval_size = config.eval_size
|
|
1038
1006
|
|
|
1007
|
+
self.position_embedding = RTDetrV2SinePositionEmbedding(
|
|
1008
|
+
embed_dim=self.encoder_hidden_dim,
|
|
1009
|
+
temperature=config.positional_encoding_temperature,
|
|
1010
|
+
)
|
|
1039
1011
|
self.layers = nn.ModuleList([RTDetrV2EncoderLayer(config) for _ in range(config.encoder_layers)])
|
|
1040
1012
|
|
|
1041
|
-
def forward(
|
|
1042
|
-
|
|
1013
|
+
def forward(
|
|
1014
|
+
self,
|
|
1015
|
+
hidden_states: torch.Tensor,
|
|
1016
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
1017
|
+
) -> torch.Tensor:
|
|
1018
|
+
"""
|
|
1019
|
+
Args:
|
|
1020
|
+
hidden_states (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
|
|
1021
|
+
Feature map to process.
|
|
1022
|
+
"""
|
|
1023
|
+
batch_size = hidden_states.shape[0]
|
|
1024
|
+
height, width = hidden_states.shape[2:]
|
|
1025
|
+
|
|
1026
|
+
hidden_states = hidden_states.flatten(2).permute(0, 2, 1)
|
|
1027
|
+
|
|
1028
|
+
if self.training or self.eval_size is None:
|
|
1029
|
+
pos_embed = self.position_embedding(
|
|
1030
|
+
width=width,
|
|
1031
|
+
height=height,
|
|
1032
|
+
device=hidden_states.device,
|
|
1033
|
+
dtype=hidden_states.dtype,
|
|
1034
|
+
)
|
|
1035
|
+
else:
|
|
1036
|
+
pos_embed = None
|
|
1037
|
+
|
|
1043
1038
|
for layer in self.layers:
|
|
1044
1039
|
hidden_states = layer(
|
|
1045
1040
|
hidden_states,
|
|
1046
|
-
attention_mask=
|
|
1047
|
-
|
|
1048
|
-
|
|
1041
|
+
attention_mask=None,
|
|
1042
|
+
spatial_position_embeddings=pos_embed,
|
|
1043
|
+
**kwargs,
|
|
1049
1044
|
)
|
|
1045
|
+
|
|
1046
|
+
hidden_states = (
|
|
1047
|
+
hidden_states.permute(0, 2, 1).reshape(batch_size, self.encoder_hidden_dim, height, width).contiguous()
|
|
1048
|
+
)
|
|
1049
|
+
|
|
1050
1050
|
return hidden_states
|
|
1051
1051
|
|
|
1052
1052
|
|
|
1053
|
-
class RTDetrV2HybridEncoder(
|
|
1053
|
+
class RTDetrV2HybridEncoder(RTDetrV2PreTrainedModel):
|
|
1054
1054
|
"""
|
|
1055
|
-
|
|
1056
|
-
(FPN) and a bottom-up Path Aggregation Network (PAN).
|
|
1055
|
+
Hybrid encoder consisting of AIFI (Attention-based Intra-scale Feature Interaction) layers,
|
|
1056
|
+
a top-down Feature Pyramid Network (FPN) and a bottom-up Path Aggregation Network (PAN).
|
|
1057
|
+
More details on the paper: https://huggingface.co/papers/2304.08069
|
|
1057
1058
|
|
|
1058
1059
|
Args:
|
|
1059
1060
|
config: RTDetrV2Config
|
|
1060
1061
|
"""
|
|
1061
1062
|
|
|
1063
|
+
_can_record_outputs = {
|
|
1064
|
+
"hidden_states": RTDetrV2AIFILayer,
|
|
1065
|
+
"attentions": RTDetrV2SelfAttention,
|
|
1066
|
+
}
|
|
1067
|
+
|
|
1062
1068
|
def __init__(self, config: RTDetrV2Config):
|
|
1063
|
-
super().__init__()
|
|
1069
|
+
super().__init__(config)
|
|
1064
1070
|
self.config = config
|
|
1065
1071
|
self.in_channels = config.encoder_in_channels
|
|
1066
1072
|
self.feat_strides = config.feat_strides
|
|
@@ -1072,10 +1078,9 @@ class RTDetrV2HybridEncoder(nn.Module):
|
|
|
1072
1078
|
self.out_strides = self.feat_strides
|
|
1073
1079
|
self.num_fpn_stages = len(self.in_channels) - 1
|
|
1074
1080
|
self.num_pan_stages = len(self.in_channels) - 1
|
|
1075
|
-
activation = config.activation_function
|
|
1076
1081
|
|
|
1077
|
-
#
|
|
1078
|
-
self.
|
|
1082
|
+
# AIFI (Attention-based Intra-scale Feature Interaction) layers
|
|
1083
|
+
self.aifi = nn.ModuleList([RTDetrV2AIFILayer(config) for _ in range(len(self.encode_proj_layers))])
|
|
1079
1084
|
|
|
1080
1085
|
# top-down FPN
|
|
1081
1086
|
self.lateral_convs = nn.ModuleList()
|
|
@@ -1087,7 +1092,7 @@ class RTDetrV2HybridEncoder(nn.Module):
|
|
|
1087
1092
|
out_channels=self.encoder_hidden_dim,
|
|
1088
1093
|
kernel_size=1,
|
|
1089
1094
|
stride=1,
|
|
1090
|
-
activation=
|
|
1095
|
+
activation=config.activation_function,
|
|
1091
1096
|
)
|
|
1092
1097
|
fpn_block = RTDetrV2CSPRepLayer(config)
|
|
1093
1098
|
self.lateral_convs.append(lateral_conv)
|
|
@@ -1103,118 +1108,36 @@ class RTDetrV2HybridEncoder(nn.Module):
|
|
|
1103
1108
|
out_channels=self.encoder_hidden_dim,
|
|
1104
1109
|
kernel_size=3,
|
|
1105
1110
|
stride=2,
|
|
1106
|
-
activation=
|
|
1111
|
+
activation=config.activation_function,
|
|
1107
1112
|
)
|
|
1108
1113
|
pan_block = RTDetrV2CSPRepLayer(config)
|
|
1109
1114
|
self.downsample_convs.append(downsample_conv)
|
|
1110
1115
|
self.pan_blocks.append(pan_block)
|
|
1111
1116
|
|
|
1112
|
-
|
|
1113
|
-
def build_2d_sincos_position_embedding(
|
|
1114
|
-
width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
|
|
1115
|
-
):
|
|
1116
|
-
grid_w = torch.arange(torch_int(width), device=device).to(dtype)
|
|
1117
|
-
grid_h = torch.arange(torch_int(height), device=device).to(dtype)
|
|
1118
|
-
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
|
1119
|
-
if embed_dim % 4 != 0:
|
|
1120
|
-
raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
|
|
1121
|
-
pos_dim = embed_dim // 4
|
|
1122
|
-
omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
|
|
1123
|
-
omega = 1.0 / (temperature**omega)
|
|
1124
|
-
|
|
1125
|
-
out_w = grid_w.flatten()[..., None] @ omega[None]
|
|
1126
|
-
out_h = grid_h.flatten()[..., None] @ omega[None]
|
|
1127
|
-
|
|
1128
|
-
return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
|
|
1117
|
+
self.post_init()
|
|
1129
1118
|
|
|
1119
|
+
@check_model_inputs(tie_last_hidden_states=False)
|
|
1130
1120
|
def forward(
|
|
1131
1121
|
self,
|
|
1132
1122
|
inputs_embeds=None,
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
spatial_shapes=None,
|
|
1136
|
-
level_start_index=None,
|
|
1137
|
-
valid_ratios=None,
|
|
1138
|
-
output_attentions=None,
|
|
1139
|
-
output_hidden_states=None,
|
|
1140
|
-
return_dict=None,
|
|
1141
|
-
):
|
|
1123
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
1124
|
+
) -> BaseModelOutput:
|
|
1142
1125
|
r"""
|
|
1143
1126
|
Args:
|
|
1144
1127
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
|
1145
1128
|
Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
|
|
1146
|
-
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
1147
|
-
Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
|
|
1148
|
-
- 1 for pixel features that are real (i.e. **not masked**),
|
|
1149
|
-
- 0 for pixel features that are padding (i.e. **masked**).
|
|
1150
|
-
[What are attention masks?](../glossary#attention-mask)
|
|
1151
|
-
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
|
1152
|
-
Position embeddings that are added to the queries and keys in each self-attention layer.
|
|
1153
|
-
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
|
|
1154
|
-
Spatial shapes of each feature map.
|
|
1155
|
-
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
|
|
1156
|
-
Starting index of each feature map.
|
|
1157
|
-
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
|
|
1158
|
-
Ratio of valid area in each feature level.
|
|
1159
|
-
output_attentions (`bool`, *optional*):
|
|
1160
|
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
1161
|
-
returned tensors for more detail.
|
|
1162
|
-
output_hidden_states (`bool`, *optional*):
|
|
1163
|
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
|
1164
|
-
for more detail.
|
|
1165
|
-
return_dict (`bool`, *optional*):
|
|
1166
|
-
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
|
1167
1129
|
"""
|
|
1168
|
-
|
|
1169
|
-
output_hidden_states = (
|
|
1170
|
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
1171
|
-
)
|
|
1172
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1173
|
-
|
|
1174
|
-
hidden_states = inputs_embeds
|
|
1130
|
+
feature_maps = inputs_embeds
|
|
1175
1131
|
|
|
1176
|
-
|
|
1177
|
-
all_attentions = () if output_attentions else None
|
|
1178
|
-
|
|
1179
|
-
# encoder
|
|
1132
|
+
# AIFI: Apply transformer encoder to specified feature levels
|
|
1180
1133
|
if self.config.encoder_layers > 0:
|
|
1181
1134
|
for i, enc_ind in enumerate(self.encode_proj_layers):
|
|
1182
|
-
|
|
1183
|
-
encoder_states = encoder_states + (hidden_states[enc_ind],)
|
|
1184
|
-
height, width = hidden_states[enc_ind].shape[2:]
|
|
1185
|
-
# flatten [batch, channel, height, width] to [batch, height*width, channel]
|
|
1186
|
-
src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
|
|
1187
|
-
if self.training or self.eval_size is None:
|
|
1188
|
-
pos_embed = self.build_2d_sincos_position_embedding(
|
|
1189
|
-
width,
|
|
1190
|
-
height,
|
|
1191
|
-
self.encoder_hidden_dim,
|
|
1192
|
-
self.positional_encoding_temperature,
|
|
1193
|
-
device=src_flatten.device,
|
|
1194
|
-
dtype=src_flatten.dtype,
|
|
1195
|
-
)
|
|
1196
|
-
else:
|
|
1197
|
-
pos_embed = None
|
|
1198
|
-
|
|
1199
|
-
layer_outputs = self.encoder[i](
|
|
1200
|
-
src_flatten,
|
|
1201
|
-
pos_embed=pos_embed,
|
|
1202
|
-
output_attentions=output_attentions,
|
|
1203
|
-
)
|
|
1204
|
-
hidden_states[enc_ind] = (
|
|
1205
|
-
layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
|
|
1206
|
-
)
|
|
1207
|
-
|
|
1208
|
-
if output_attentions:
|
|
1209
|
-
all_attentions = all_attentions + (layer_outputs[1],)
|
|
1210
|
-
|
|
1211
|
-
if output_hidden_states:
|
|
1212
|
-
encoder_states = encoder_states + (hidden_states[enc_ind],)
|
|
1135
|
+
feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs)
|
|
1213
1136
|
|
|
1214
1137
|
# top-down FPN
|
|
1215
|
-
fpn_feature_maps = [
|
|
1138
|
+
fpn_feature_maps = [feature_maps[-1]]
|
|
1216
1139
|
for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
|
|
1217
|
-
backbone_feature_map =
|
|
1140
|
+
backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1]
|
|
1218
1141
|
top_fpn_feature_map = fpn_feature_maps[-1]
|
|
1219
1142
|
# apply lateral block
|
|
1220
1143
|
top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
|
|
@@ -1237,11 +1160,7 @@ class RTDetrV2HybridEncoder(nn.Module):
|
|
|
1237
1160
|
new_pan_feature_map = pan_block(fused_feature_map)
|
|
1238
1161
|
pan_feature_maps.append(new_pan_feature_map)
|
|
1239
1162
|
|
|
1240
|
-
|
|
1241
|
-
return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None)
|
|
1242
|
-
return BaseModelOutput(
|
|
1243
|
-
last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions
|
|
1244
|
-
)
|
|
1163
|
+
return BaseModelOutput(last_hidden_state=pan_feature_maps)
|
|
1245
1164
|
|
|
1246
1165
|
|
|
1247
1166
|
def get_contrastive_denoising_training_group(
|
|
@@ -1384,8 +1303,8 @@ class RTDetrV2Model(RTDetrV2PreTrainedModel):
|
|
|
1384
1303
|
# https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/hybrid_encoder.py#L212
|
|
1385
1304
|
num_backbone_outs = len(intermediate_channel_sizes)
|
|
1386
1305
|
encoder_input_proj_list = []
|
|
1387
|
-
for
|
|
1388
|
-
in_channels = intermediate_channel_sizes[
|
|
1306
|
+
for i in range(num_backbone_outs):
|
|
1307
|
+
in_channels = intermediate_channel_sizes[i]
|
|
1389
1308
|
encoder_input_proj_list.append(
|
|
1390
1309
|
nn.Sequential(
|
|
1391
1310
|
nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
|
|
@@ -1413,7 +1332,7 @@ class RTDetrV2Model(RTDetrV2PreTrainedModel):
|
|
|
1413
1332
|
nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
|
|
1414
1333
|
)
|
|
1415
1334
|
self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
|
|
1416
|
-
self.enc_bbox_head = RTDetrV2MLPPredictionHead(config
|
|
1335
|
+
self.enc_bbox_head = RTDetrV2MLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)
|
|
1417
1336
|
|
|
1418
1337
|
# init encoder output anchors and valid_mask
|
|
1419
1338
|
if config.anchor_image_size:
|
|
@@ -1423,8 +1342,8 @@ class RTDetrV2Model(RTDetrV2PreTrainedModel):
|
|
|
1423
1342
|
# https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L412
|
|
1424
1343
|
num_backbone_outs = len(config.decoder_in_channels)
|
|
1425
1344
|
decoder_input_proj_list = []
|
|
1426
|
-
for
|
|
1427
|
-
in_channels = config.decoder_in_channels[
|
|
1345
|
+
for i in range(num_backbone_outs):
|
|
1346
|
+
in_channels = config.decoder_in_channels[i]
|
|
1428
1347
|
decoder_input_proj_list.append(
|
|
1429
1348
|
nn.Sequential(
|
|
1430
1349
|
nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
|
|
@@ -1483,26 +1402,20 @@ class RTDetrV2Model(RTDetrV2PreTrainedModel):
|
|
|
1483
1402
|
return anchors, valid_mask
|
|
1484
1403
|
|
|
1485
1404
|
@auto_docstring
|
|
1405
|
+
@can_return_tuple
|
|
1486
1406
|
def forward(
|
|
1487
1407
|
self,
|
|
1488
1408
|
pixel_values: torch.FloatTensor,
|
|
1489
1409
|
pixel_mask: torch.LongTensor | None = None,
|
|
1490
1410
|
encoder_outputs: torch.FloatTensor | None = None,
|
|
1491
1411
|
inputs_embeds: torch.FloatTensor | None = None,
|
|
1492
|
-
decoder_inputs_embeds: torch.FloatTensor | None = None,
|
|
1493
1412
|
labels: list[dict] | None = None,
|
|
1494
|
-
|
|
1495
|
-
output_hidden_states: bool | None = None,
|
|
1496
|
-
return_dict: bool | None = None,
|
|
1497
|
-
**kwargs,
|
|
1413
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
1498
1414
|
) -> tuple[torch.FloatTensor] | RTDetrV2ModelOutput:
|
|
1499
1415
|
r"""
|
|
1500
1416
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
1501
1417
|
Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
|
|
1502
1418
|
can choose to directly pass a flattened representation of an image.
|
|
1503
|
-
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
|
|
1504
|
-
Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
|
|
1505
|
-
embedded representation.
|
|
1506
1419
|
labels (`list[Dict]` of len `(batch_size,)`, *optional*):
|
|
1507
1420
|
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
             following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
@@ -1530,53 +1443,46 @@ class RTDetrV2Model(RTDetrV2PreTrainedModel):
         >>> list(last_hidden_states.shape)
         [1, 300, 256]
         ```"""
-        [14 removed lines not shown in this diff view]
-        proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+        if pixel_values is None and inputs_embeds is None:
+            raise ValueError("You have to specify either pixel_values or inputs_embeds")
+
+        if inputs_embeds is None:
+            batch_size, num_channels, height, width = pixel_values.shape
+            device = pixel_values.device
+            if pixel_mask is None:
+                pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+            features = self.backbone(pixel_values, pixel_mask)
+            proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+        else:
+            batch_size = inputs_embeds.shape[0]
+            device = inputs_embeds.device
+            proj_feats = inputs_embeds

         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 proj_feats,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
+                **kwargs,
             )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+        elif not isinstance(encoder_outputs, BaseModelOutput):
             encoder_outputs = BaseModelOutput(
                 last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if output_hidden_states else None,
-                attentions=encoder_outputs[2]
-                if len(encoder_outputs) > 2
-                else encoder_outputs[1]
-                if output_attentions
-                else None,
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

         # Equivalent to def _get_encoder_input
         # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L412
         sources = []
-        for level, source in enumerate(encoder_outputs[0]):
+        for level, source in enumerate(encoder_outputs.last_hidden_state):
             sources.append(self.decoder_input_proj[level](source))

         # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
         if self.config.num_feature_levels > len(sources):
             _len_sources = len(sources)
-            sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1])
+            sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state)[-1])
             for i in range(_len_sources + 1, self.config.num_feature_levels):
-                sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1]))
+                sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1]))

         # Prepare encoder inputs (by flattening)
         source_flatten = []
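The rewritten input path above accepts either raw `pixel_values` (run through the backbone and `encoder_input_proj`) or pre-projected `inputs_embeds`, and fills in an all-ones `pixel_mask` when none is supplied. Below is a minimal sketch of the `pixel_values` path, using a config-initialized model (random weights) so it runs without downloading a checkpoint; the 640x640 input size and the `[1, 300, 256]` output shape follow the default config and the docstring example above.

```python
import torch
from transformers import RTDetrV2Config, RTDetrV2Model

# Randomly initialized model from the default config; swap in
# RTDetrV2Model.from_pretrained(...) to use a real checkpoint.
model = RTDetrV2Model(RTDetrV2Config()).eval()

pixel_values = torch.randn(1, 3, 640, 640)
with torch.no_grad():
    # pixel_mask is omitted, so the forward pass above substitutes torch.ones(...).
    outputs = model(pixel_values=pixel_values)

print(list(outputs.last_hidden_state.shape))  # [1, 300, 256] with the default 300 queries / d_model=256
```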
@@ -1668,22 +1574,9 @@ class RTDetrV2Model(RTDetrV2PreTrainedModel):
             spatial_shapes=spatial_shapes,
             spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

-        if not return_dict:
-            enc_outputs = tuple(
-                value
-                for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits]
-                if value is not None
-            )
-            dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values])
-            tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs
-
-            return tuple_outputs
-
         return RTDetrV2ModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
             intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
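With the `not return_dict` branch removed, `RTDetrV2Model.forward` always returns an `RTDetrV2ModelOutput`; positional indexing gives way to named fields, and a tuple can still be recovered via `ModelOutput.to_tuple()`. A short sketch, reusing the randomly initialized model pattern from the previous example:

```python
import torch
from transformers import RTDetrV2Config, RTDetrV2Model

model = RTDetrV2Model(RTDetrV2Config()).eval()
with torch.no_grad():
    outputs = model(pixel_values=torch.randn(1, 3, 640, 640))

decoder_states = outputs.last_hidden_state             # (batch, num_queries, d_model)
per_layer_states = outputs.intermediate_hidden_states  # stacked decoder-layer states
legacy_tuple = outputs.to_tuple()                      # positional form, if a tuple is still needed
```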
@@ -1706,21 +1599,17 @@ class RTDetrV2Model(RTDetrV2PreTrainedModel):
         )


-# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
 class RTDetrV2MLPPredictionHead(nn.Module):
     """
     Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
     height and width of a bounding box w.r.t. an image.

-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-    Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_paddle/ppdet/modeling/transformers/utils.py#L453
-
     """

-    def __init__(self, config, input_dim, d_model, output_dim, num_layers):
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
         super().__init__()
         self.num_layers = num_layers
-        h = [d_model] * (num_layers - 1)
+        h = [hidden_dim] * (num_layers - 1)
         self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

     def forward(self, x):
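The simplified constructor builds `num_layers` linear layers whose in/out widths come from zipping `[input_dim] + h` with `h + [output_dim]`. Below is a standalone mirror of that constructor to make the resulting shapes concrete; the forward body is not part of this hunk, so the DETR-style ReLU-between-layers forward shown here is an assumption.

```python
import torch
from torch import nn


class MLPPredictionHead(nn.Module):
    """Illustrative mirror of RTDetrV2MLPPredictionHead's constructor logic."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)  # hidden widths between the input and output layers
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        # Assumed DETR-style forward: ReLU on every layer except the last.
        for i, layer in enumerate(self.layers):
            x = torch.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


head = MLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
print([(layer.in_features, layer.out_features) for layer in head.layers])  # [(256, 256), (256, 256), (256, 4)]
print(head(torch.randn(1, 300, 256)).shape)                                # torch.Size([1, 300, 4])
```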
@@ -1820,8 +1709,8 @@ class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel):
     _tied_weights_keys = {
         r"bbox_embed.(?![0])\d+": r"bbox_embed.0",
         r"class_embed.(?![0])\d+": r"^class_embed.0",
-        [two removed tied-weight entries truncated in this diff view]
+        "class_embed": "model.decoder.class_embed",
+        "bbox_embed": "model.decoder.bbox_embed",
     }

     def __init__(self, config: RTDetrV2Config):
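The two regex keys carried over from the previous release tie every `bbox_embed.<i>` / `class_embed.<i>` entry (for `i != 0`, thanks to the `(?![0])` negative lookahead) back to index 0. The snippet below only illustrates what the pattern matches; it is not the library's actual weight-tying code.

```python
import re

# Same pattern as the _tied_weights_keys entry above; the unescaped "." matches any character.
pattern = re.compile(r"bbox_embed.(?![0])\d+")

for name in ["bbox_embed.0", "bbox_embed.1", "bbox_embed.5", "class_embed.2"]:
    target = "bbox_embed.0" if pattern.fullmatch(name) else name
    print(f"{name} -> {target}")
# bbox_embed.0 -> bbox_embed.0   (the lookahead rejects index 0, so it stays untied)
# bbox_embed.1 -> bbox_embed.0
# bbox_embed.5 -> bbox_embed.0
# class_embed.2 -> class_embed.2 (different prefix; covered by the class_embed pattern instead)
```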
@@ -1833,7 +1722,7 @@ class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel):
         )
         self.bbox_embed = nn.ModuleList(
             [
-                RTDetrV2MLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3)
+                RTDetrV2MLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)
                 for _ in range(config.decoder_layers)
             ]
         )
@@ -1847,26 +1736,20 @@ class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel):
         return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
         pixel_mask: torch.LongTensor | None = None,
         encoder_outputs: torch.FloatTensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
-        decoder_inputs_embeds: torch.FloatTensor | None = None,
         labels: list[dict] | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | RTDetrV2ObjectDetectionOutput:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
             can choose to directly pass a flattened representation of an image.
-        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
-            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
         labels (`list[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
             following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
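The explicit `output_attentions` / `output_hidden_states` / `return_dict` parameters are gone from the signature; those flags now travel through `**kwargs: Unpack[TransformersKwargs]`, with `@can_return_tuple` restoring the tuple form on request. A hedged sketch of the call-site effect, assuming these kwarg names are the ones accepted by `TransformersKwargs` in this release:

```python
import torch
from transformers import RTDetrV2Config, RTDetrV2ForObjectDetection

model = RTDetrV2ForObjectDetection(RTDetrV2Config()).eval()
pixel_values = torch.randn(1, 3, 640, 640)

with torch.no_grad():
    # The flags are plain keyword arguments now, forwarded through **kwargs.
    outputs = model(pixel_values, output_hidden_states=True)
    # @can_return_tuple converts the structured output back to a tuple when asked.
    legacy = model(pixel_values, return_dict=False)

print(type(outputs).__name__)  # expected: RTDetrV2ObjectDetectionOutput
print(type(legacy))            # expected: <class 'tuple'>
```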
@@ -1919,40 +1802,29 @@ class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel):
         Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48]
         Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99]
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         outputs = self.model(
             pixel_values,
             pixel_mask=pixel_mask,
             encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
             labels=labels,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

-        denoising_meta_values = (
-            outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None
-        )
+        denoising_meta_values = outputs.denoising_meta_values if self.training else None

-        outputs_class = outputs.intermediate_logits
-        outputs_coord = outputs.intermediate_reference_points
-        predicted_corners = outputs.intermediate_predicted_corners
-        initial_reference_points = outputs.initial_reference_points
+        outputs_class = outputs.intermediate_logits
+        outputs_coord = outputs.intermediate_reference_points
+        predicted_corners = outputs.intermediate_predicted_corners
+        initial_reference_points = outputs.initial_reference_points

         logits = outputs_class[:, -1]
         pred_boxes = outputs_coord[:, -1]

         loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
         if labels is not None:
-            enc_topk_logits = outputs.enc_topk_logits
-            enc_topk_bboxes = outputs.enc_topk_bboxes
+            enc_topk_logits = outputs.enc_topk_logits
+            enc_topk_bboxes = outputs.enc_topk_bboxes
             loss, loss_dict, auxiliary_outputs = self.loss_function(
                 logits,
                 labels,
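The loss path only runs when `labels` is passed, and (per the docstring above) each entry must provide `'class_labels'` and `'boxes'`. A minimal training-style sketch, assuming DETR-family conventions of normalized `(center_x, center_y, width, height)` boxes and a `num_labels` set on the config:

```python
import torch
from transformers import RTDetrV2Config, RTDetrV2ForObjectDetection

model = RTDetrV2ForObjectDetection(RTDetrV2Config(num_labels=91))  # nn.Module default: training mode
pixel_values = torch.randn(2, 3, 640, 640)

# One dict per image: integer class indices plus normalized (cx, cy, w, h) boxes.
labels = [
    {"class_labels": torch.tensor([1, 7]), "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.3], [0.1, 0.2, 0.1, 0.1]])},
    {"class_labels": torch.tensor([3]), "boxes": torch.tensor([[0.4, 0.6, 0.5, 0.5]])},
]

outputs = model(pixel_values, labels=labels)
print(outputs.loss)                                    # scalar bipartite-matching loss
print(outputs.logits.shape, outputs.pred_boxes.shape)  # per-query class logits and normalized boxes
```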
@@ -1969,13 +1841,6 @@ class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel):
                 **kwargs,
             )

-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs
-            else:
-                output = (logits, pred_boxes) + outputs
-            return ((loss, loss_dict) + output) if loss is not None else output
-
         return RTDetrV2ObjectDetectionOutput(
             loss=loss,
             loss_dict=loss_dict,
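Since the tuple branch is gone here as well, inference code reads named fields off the `RTDetrV2ObjectDetectionOutput` and hands the whole object to the image processor, mirroring the docstring example whose output is quoted above. The checkpoint name below is the RT-DETRv2 checkpoint commonly used in the library's examples and is an assumption; substitute whichever checkpoint you actually use.

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, RTDetrV2ForObjectDetection

checkpoint = "PekingU/rtdetr_v2_r18vd"  # assumed example checkpoint
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = RTDetrV2ForObjectDetection.from_pretrained(checkpoint).eval()

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # always an RTDetrV2ObjectDetectionOutput

results = processor.post_process_object_detection(
    outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.9
)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(f"{model.config.id2label[label.item()]}: {score:.3f} at {[round(x, 2) for x in box.tolist()]}")
```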