transformers-5.0.0rc3-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as published in their public registries.
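Before reviewing the file-level changes below, a minimal sketch for reproducing the version bump locally (assuming a standard pip environment; the package index, extras, and any pinned dependencies are not specified by this diff):

```python
# Minimal sketch: upgrade to the new release and confirm the installed version.
# Shell step (assumed environment):
#   pip install --upgrade "transformers==5.1.0"
import transformers

# After the upgrade this is expected to print "5.1.0".
print(transformers.__version__)
```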
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,9 @@
|
|
|
1
|
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
2
|
+
# This file was automatically generated from src/transformers/models/rt_detr/modular_rt_detr.py.
|
|
3
|
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
4
|
+
# the file from the modular. If any change should be done, please apply the change to the
|
|
5
|
+
# modular_rt_detr.py file directly. One of our CI enforces this.
|
|
6
|
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
1
7
|
# Copyright 2024 Baidu Inc and The HuggingFace Inc. team.
|
|
2
8
|
#
|
|
3
9
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -11,10 +17,9 @@
|
|
|
11
17
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
18
|
# See the License for the specific language governing permissions and
|
|
13
19
|
# limitations under the License.
|
|
14
|
-
"""PyTorch RT-DETR model."""
|
|
15
|
-
|
|
16
20
|
import math
|
|
17
21
|
import warnings
|
|
22
|
+
from collections.abc import Callable
|
|
18
23
|
from dataclasses import dataclass
|
|
19
24
|
|
|
20
25
|
import torch
|
|
@@ -23,83 +28,18 @@ from torch import Tensor, nn
|
|
|
23
28
|
|
|
24
29
|
from ... import initialization as init
|
|
25
30
|
from ...activations import ACT2CLS, ACT2FN
|
|
31
|
+
from ...backbone_utils import load_backbone
|
|
26
32
|
from ...image_transforms import center_to_corners_format, corners_to_center_format
|
|
27
33
|
from ...integrations import use_kernel_forward_from_hub
|
|
28
34
|
from ...modeling_outputs import BaseModelOutput
|
|
29
|
-
from ...modeling_utils import PreTrainedModel
|
|
35
|
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
36
|
+
from ...processing_utils import Unpack
|
|
30
37
|
from ...pytorch_utils import compile_compatible_method_lru_cache
|
|
31
|
-
from ...utils import
|
|
32
|
-
|
|
33
|
-
auto_docstring,
|
|
34
|
-
logging,
|
|
35
|
-
torch_int,
|
|
36
|
-
)
|
|
37
|
-
from ...utils.backbone_utils import load_backbone
|
|
38
|
+
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int
|
|
39
|
+
from ...utils.generic import can_return_tuple, check_model_inputs
|
|
38
40
|
from .configuration_rt_detr import RTDetrConfig
|
|
39
41
|
|
|
40
42
|
|
|
41
|
-
logger = logging.get_logger(__name__)
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
# TODO: Replace all occurrences of the checkpoint with the final one
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
|
|
48
|
-
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention
|
|
49
|
-
class MultiScaleDeformableAttention(nn.Module):
|
|
50
|
-
def forward(
|
|
51
|
-
self,
|
|
52
|
-
value: Tensor,
|
|
53
|
-
value_spatial_shapes: Tensor,
|
|
54
|
-
value_spatial_shapes_list: list[tuple],
|
|
55
|
-
level_start_index: Tensor,
|
|
56
|
-
sampling_locations: Tensor,
|
|
57
|
-
attention_weights: Tensor,
|
|
58
|
-
im2col_step: int,
|
|
59
|
-
):
|
|
60
|
-
batch_size, _, num_heads, hidden_dim = value.shape
|
|
61
|
-
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
|
|
62
|
-
value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
|
|
63
|
-
sampling_grids = 2 * sampling_locations - 1
|
|
64
|
-
sampling_value_list = []
|
|
65
|
-
for level_id, (height, width) in enumerate(value_spatial_shapes_list):
|
|
66
|
-
# batch_size, height*width, num_heads, hidden_dim
|
|
67
|
-
# -> batch_size, height*width, num_heads*hidden_dim
|
|
68
|
-
# -> batch_size, num_heads*hidden_dim, height*width
|
|
69
|
-
# -> batch_size*num_heads, hidden_dim, height, width
|
|
70
|
-
value_l_ = (
|
|
71
|
-
value_list[level_id]
|
|
72
|
-
.flatten(2)
|
|
73
|
-
.transpose(1, 2)
|
|
74
|
-
.reshape(batch_size * num_heads, hidden_dim, height, width)
|
|
75
|
-
)
|
|
76
|
-
# batch_size, num_queries, num_heads, num_points, 2
|
|
77
|
-
# -> batch_size, num_heads, num_queries, num_points, 2
|
|
78
|
-
# -> batch_size*num_heads, num_queries, num_points, 2
|
|
79
|
-
sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
|
|
80
|
-
# batch_size*num_heads, hidden_dim, num_queries, num_points
|
|
81
|
-
sampling_value_l_ = nn.functional.grid_sample(
|
|
82
|
-
value_l_,
|
|
83
|
-
sampling_grid_l_,
|
|
84
|
-
mode="bilinear",
|
|
85
|
-
padding_mode="zeros",
|
|
86
|
-
align_corners=False,
|
|
87
|
-
)
|
|
88
|
-
sampling_value_list.append(sampling_value_l_)
|
|
89
|
-
# (batch_size, num_queries, num_heads, num_levels, num_points)
|
|
90
|
-
# -> (batch_size, num_heads, num_queries, num_levels, num_points)
|
|
91
|
-
# -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
|
|
92
|
-
attention_weights = attention_weights.transpose(1, 2).reshape(
|
|
93
|
-
batch_size * num_heads, 1, num_queries, num_levels * num_points
|
|
94
|
-
)
|
|
95
|
-
output = (
|
|
96
|
-
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
|
|
97
|
-
.sum(-1)
|
|
98
|
-
.view(batch_size, num_heads * hidden_dim, num_queries)
|
|
99
|
-
)
|
|
100
|
-
return output.transpose(1, 2).contiguous()
|
|
101
|
-
|
|
102
|
-
|
|
103
43
|
@dataclass
|
|
104
44
|
@auto_docstring(
|
|
105
45
|
custom_intro="""
|
|
@@ -274,19 +214,23 @@ class RTDetrObjectDetectionOutput(ModelOutput):
|
|
|
274
214
|
denoising_meta_values: dict | None = None
|
|
275
215
|
|
|
276
216
|
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
217
|
+
class RTDetrMLP(nn.Module):
|
|
218
|
+
def __init__(self, config: RTDetrConfig, hidden_size: int, intermediate_size: int, activation_function: str):
|
|
219
|
+
super().__init__()
|
|
220
|
+
self.fc1 = nn.Linear(hidden_size, intermediate_size)
|
|
221
|
+
self.fc2 = nn.Linear(intermediate_size, hidden_size)
|
|
222
|
+
self.activation_fn = ACT2FN[activation_function]
|
|
223
|
+
self.activation_dropout = config.activation_dropout
|
|
224
|
+
self.dropout = config.dropout
|
|
280
225
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
226
|
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
227
|
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
|
228
|
+
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
|
229
|
+
hidden_states = self.fc2(hidden_states)
|
|
230
|
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
231
|
+
return hidden_states
|
|
287
232
|
|
|
288
233
|
|
|
289
|
-
# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->RTDetr
|
|
290
234
|
class RTDetrFrozenBatchNorm2d(nn.Module):
|
|
291
235
|
"""
|
|
292
236
|
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
|
@@ -326,152 +270,123 @@ class RTDetrFrozenBatchNorm2d(nn.Module):
|
|
|
326
270
|
return x * scale + bias
|
|
327
271
|
|
|
328
272
|
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
273
|
+
def eager_attention_forward(
|
|
274
|
+
module: nn.Module,
|
|
275
|
+
query: torch.Tensor,
|
|
276
|
+
key: torch.Tensor,
|
|
277
|
+
value: torch.Tensor,
|
|
278
|
+
attention_mask: torch.Tensor | None,
|
|
279
|
+
scaling: float | None = None,
|
|
280
|
+
dropout: float = 0.0,
|
|
281
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
282
|
+
):
|
|
283
|
+
if scaling is None:
|
|
284
|
+
scaling = query.size(-1) ** -0.5
|
|
333
285
|
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
input model
|
|
337
|
-
"""
|
|
338
|
-
for name, module in model.named_children():
|
|
339
|
-
if isinstance(module, nn.BatchNorm2d):
|
|
340
|
-
new_module = RTDetrFrozenBatchNorm2d(module.num_features)
|
|
286
|
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
|
287
|
+
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
|
341
288
|
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
new_module.running_mean.copy_(module.running_mean)
|
|
346
|
-
new_module.running_var.copy_(module.running_var)
|
|
289
|
+
if attention_mask is not None:
|
|
290
|
+
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
|
291
|
+
attn_weights = attn_weights + attention_mask
|
|
347
292
|
|
|
348
|
-
|
|
293
|
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
|
294
|
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
|
349
295
|
|
|
350
|
-
|
|
351
|
-
|
|
296
|
+
attn_output = torch.matmul(attn_weights, value)
|
|
297
|
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
352
298
|
|
|
299
|
+
return attn_output, attn_weights
|
|
353
300
|
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
num_classes,
|
|
357
|
-
num_queries,
|
|
358
|
-
class_embed,
|
|
359
|
-
num_denoising_queries=100,
|
|
360
|
-
label_noise_ratio=0.5,
|
|
361
|
-
box_noise_scale=1.0,
|
|
362
|
-
):
|
|
301
|
+
|
|
302
|
+
class RTDetrSelfAttention(nn.Module):
|
|
363
303
|
"""
|
|
364
|
-
|
|
304
|
+
Multi-headed self-attention from 'Attention Is All You Need' paper.
|
|
365
305
|
|
|
366
|
-
|
|
367
|
-
targets (`list[dict]`):
|
|
368
|
-
The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
|
|
369
|
-
num_classes (`int`):
|
|
370
|
-
Total number of classes in the dataset.
|
|
371
|
-
num_queries (`int`):
|
|
372
|
-
Number of query slots in the transformer.
|
|
373
|
-
class_embed (`callable`):
|
|
374
|
-
A function or a model layer to embed class labels.
|
|
375
|
-
num_denoising_queries (`int`, *optional*, defaults to 100):
|
|
376
|
-
Number of denoising queries.
|
|
377
|
-
label_noise_ratio (`float`, *optional*, defaults to 0.5):
|
|
378
|
-
Ratio of noise applied to labels.
|
|
379
|
-
box_noise_scale (`float`, *optional*, defaults to 1.0):
|
|
380
|
-
Scale of noise applied to bounding boxes.
|
|
381
|
-
Returns:
|
|
382
|
-
`tuple` comprising various elements:
|
|
383
|
-
- **input_query_class** (`torch.FloatTensor`) --
|
|
384
|
-
Class queries with applied label noise.
|
|
385
|
-
- **input_query_bbox** (`torch.FloatTensor`) --
|
|
386
|
-
Bounding box queries with applied box noise.
|
|
387
|
-
- **attn_mask** (`torch.FloatTensor`) --
|
|
388
|
-
Attention mask for separating denoising and reconstruction queries.
|
|
389
|
-
- **denoising_meta_values** (`dict`) --
|
|
390
|
-
Metadata including denoising positive indices, number of groups, and split sizes.
|
|
306
|
+
In RT_DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
|
|
391
307
|
"""
|
|
392
308
|
|
|
393
|
-
|
|
394
|
-
|
|
309
|
+
def __init__(
|
|
310
|
+
self,
|
|
311
|
+
config: RTDetrConfig,
|
|
312
|
+
hidden_size: int,
|
|
313
|
+
num_attention_heads: int,
|
|
314
|
+
dropout: float = 0.0,
|
|
315
|
+
bias: bool = True,
|
|
316
|
+
):
|
|
317
|
+
super().__init__()
|
|
318
|
+
self.config = config
|
|
319
|
+
self.head_dim = hidden_size // num_attention_heads
|
|
320
|
+
self.scaling = self.head_dim**-0.5
|
|
321
|
+
self.attention_dropout = dropout
|
|
322
|
+
self.is_causal = False
|
|
395
323
|
|
|
396
|
-
|
|
397
|
-
|
|
324
|
+
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
|
|
325
|
+
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
|
|
326
|
+
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
|
|
327
|
+
self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
|
|
398
328
|
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
329
|
+
def forward(
|
|
330
|
+
self,
|
|
331
|
+
hidden_states: torch.Tensor,
|
|
332
|
+
attention_mask: torch.Tensor | None = None,
|
|
333
|
+
position_embeddings: torch.Tensor | None = None,
|
|
334
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
335
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
336
|
+
"""
|
|
337
|
+
Position embeddings are added to both queries and keys (but not values).
|
|
338
|
+
"""
|
|
339
|
+
input_shape = hidden_states.shape[:-1]
|
|
340
|
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
|
402
341
|
|
|
403
|
-
|
|
404
|
-
num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
|
|
405
|
-
# pad gt to max_num of a batch
|
|
406
|
-
batch_size = len(num_ground_truths)
|
|
342
|
+
query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
|
|
407
343
|
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
344
|
+
query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
|
|
345
|
+
key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
|
|
346
|
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
411
347
|
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
input_query_class[i, :num_gt] = targets[i]["class_labels"]
|
|
416
|
-
input_query_bbox[i, :num_gt] = targets[i]["boxes"]
|
|
417
|
-
pad_gt_mask[i, :num_gt] = 1
|
|
418
|
-
# each group has positive and negative queries.
|
|
419
|
-
input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
|
|
420
|
-
input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
|
|
421
|
-
pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
|
|
422
|
-
# positive and negative mask
|
|
423
|
-
negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
|
|
424
|
-
negative_gt_mask[:, max_gt_num:] = 1
|
|
425
|
-
negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
|
|
426
|
-
positive_gt_mask = 1 - negative_gt_mask
|
|
427
|
-
# contrastive denoising training positive index
|
|
428
|
-
positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
|
|
429
|
-
denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
|
|
430
|
-
denoise_positive_idx = torch.split(
|
|
431
|
-
denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
|
|
432
|
-
)
|
|
433
|
-
# total denoising queries
|
|
434
|
-
num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
|
|
348
|
+
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
|
349
|
+
self.config._attn_implementation, eager_attention_forward
|
|
350
|
+
)
|
|
435
351
|
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
352
|
+
attn_output, attn_weights = attention_interface(
|
|
353
|
+
self,
|
|
354
|
+
query_states,
|
|
355
|
+
key_states,
|
|
356
|
+
value_states,
|
|
357
|
+
attention_mask,
|
|
358
|
+
dropout=0.0 if not self.training else self.attention_dropout,
|
|
359
|
+
scaling=self.scaling,
|
|
360
|
+
**kwargs,
|
|
361
|
+
)
|
|
441
362
|
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
|
|
446
|
-
rand_part = torch.rand_like(input_query_bbox)
|
|
447
|
-
rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
|
|
448
|
-
rand_part *= rand_sign
|
|
449
|
-
known_bbox += rand_part * diff
|
|
450
|
-
known_bbox.clip_(min=0.0, max=1.0)
|
|
451
|
-
input_query_bbox = corners_to_center_format(known_bbox)
|
|
452
|
-
input_query_bbox = inverse_sigmoid(input_query_bbox)
|
|
363
|
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
|
364
|
+
attn_output = self.o_proj(attn_output)
|
|
365
|
+
return attn_output, attn_weights
|
|
453
366
|
|
|
454
|
-
input_query_class = class_embed(input_query_class)
|
|
455
367
|
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
|
|
368
|
+
def replace_batch_norm(model):
|
|
369
|
+
r"""
|
|
370
|
+
Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrFrozenBatchNorm2d`.
|
|
460
371
|
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
372
|
+
Args:
|
|
373
|
+
model (torch.nn.Module):
|
|
374
|
+
input model
|
|
375
|
+
"""
|
|
376
|
+
for name, module in model.named_children():
|
|
377
|
+
if isinstance(module, nn.BatchNorm2d):
|
|
378
|
+
new_module = RTDetrFrozenBatchNorm2d(module.num_features)
|
|
467
379
|
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
380
|
+
if module.weight.device != torch.device("meta"):
|
|
381
|
+
new_module.weight.copy_(module.weight)
|
|
382
|
+
new_module.bias.copy_(module.bias)
|
|
383
|
+
new_module.running_mean.copy_(module.running_mean)
|
|
384
|
+
new_module.running_var.copy_(module.running_var)
|
|
473
385
|
|
|
474
|
-
|
|
386
|
+
model._modules[name] = new_module
|
|
387
|
+
|
|
388
|
+
if len(list(module.children())) > 0:
|
|
389
|
+
replace_batch_norm(module)
|
|
475
390
|
|
|
476
391
|
|
|
477
392
|
class RTDetrConvEncoder(nn.Module):
|
|
@@ -531,50 +446,46 @@ class RTDetrEncoderLayer(nn.Module):
|
|
|
531
446
|
def __init__(self, config: RTDetrConfig):
|
|
532
447
|
super().__init__()
|
|
533
448
|
self.normalize_before = config.normalize_before
|
|
449
|
+
self.hidden_size = config.encoder_hidden_dim
|
|
534
450
|
|
|
535
451
|
# self-attention
|
|
536
|
-
self.self_attn =
|
|
537
|
-
|
|
538
|
-
|
|
452
|
+
self.self_attn = RTDetrSelfAttention(
|
|
453
|
+
config=config,
|
|
454
|
+
hidden_size=self.hidden_size,
|
|
455
|
+
num_attention_heads=config.num_attention_heads,
|
|
539
456
|
dropout=config.dropout,
|
|
540
457
|
)
|
|
541
|
-
self.self_attn_layer_norm = nn.LayerNorm(
|
|
458
|
+
self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
|
542
459
|
self.dropout = config.dropout
|
|
543
|
-
self.
|
|
544
|
-
self.
|
|
545
|
-
self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim)
|
|
546
|
-
self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim)
|
|
547
|
-
self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
|
|
460
|
+
self.mlp = RTDetrMLP(config, self.hidden_size, config.encoder_ffn_dim, config.encoder_activation_function)
|
|
461
|
+
self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
|
548
462
|
|
|
549
463
|
def forward(
|
|
550
464
|
self,
|
|
551
465
|
hidden_states: torch.Tensor,
|
|
552
466
|
attention_mask: torch.Tensor,
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
):
|
|
467
|
+
spatial_position_embeddings: torch.Tensor | None = None,
|
|
468
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
469
|
+
) -> torch.Tensor:
|
|
557
470
|
"""
|
|
558
471
|
Args:
|
|
559
|
-
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len,
|
|
472
|
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
|
|
560
473
|
attention_mask (`torch.FloatTensor`): attention mask of size
|
|
561
474
|
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
|
562
475
|
values.
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
567
|
-
returned tensors for more detail.
|
|
476
|
+
spatial_position_embeddings (`torch.FloatTensor`, *optional*):
|
|
477
|
+
Spatial position embeddings (2D positional encodings of image locations), to be added to both
|
|
478
|
+
the queries and keys in self-attention (but not to values).
|
|
568
479
|
"""
|
|
569
480
|
residual = hidden_states
|
|
570
481
|
if self.normalize_before:
|
|
571
482
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
|
572
483
|
|
|
573
|
-
hidden_states,
|
|
484
|
+
hidden_states, _ = self.self_attn(
|
|
574
485
|
hidden_states=hidden_states,
|
|
575
486
|
attention_mask=attention_mask,
|
|
576
|
-
position_embeddings=
|
|
577
|
-
|
|
487
|
+
position_embeddings=spatial_position_embeddings,
|
|
488
|
+
**kwargs,
|
|
578
489
|
)
|
|
579
490
|
|
|
580
491
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
@@ -586,12 +497,7 @@ class RTDetrEncoderLayer(nn.Module):
|
|
|
586
497
|
hidden_states = self.final_layer_norm(hidden_states)
|
|
587
498
|
residual = hidden_states
|
|
588
499
|
|
|
589
|
-
hidden_states = self.
|
|
590
|
-
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
|
591
|
-
|
|
592
|
-
hidden_states = self.fc2(hidden_states)
|
|
593
|
-
|
|
594
|
-
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
500
|
+
hidden_states = self.mlp(hidden_states)
|
|
595
501
|
|
|
596
502
|
hidden_states = residual + hidden_states
|
|
597
503
|
if not self.normalize_before:
|
|
@@ -602,12 +508,7 @@ class RTDetrEncoderLayer(nn.Module):
|
|
|
602
508
|
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
|
603
509
|
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
|
604
510
|
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
if output_attentions:
|
|
608
|
-
outputs += (attn_weights,)
|
|
609
|
-
|
|
610
|
-
return outputs
|
|
511
|
+
return hidden_states
|
|
611
512
|
|
|
612
513
|
|
|
613
514
|
class RTDetrRepVggBlock(nn.Module):
|
|
@@ -658,7 +559,61 @@ class RTDetrCSPRepLayer(nn.Module):
|
|
|
658
559
|
return self.conv3(hidden_state_1 + hidden_state_2)
|
|
659
560
|
|
|
660
561
|
|
|
661
|
-
|
|
562
|
+
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
|
|
563
|
+
class MultiScaleDeformableAttention(nn.Module):
|
|
564
|
+
def forward(
|
|
565
|
+
self,
|
|
566
|
+
value: Tensor,
|
|
567
|
+
value_spatial_shapes: Tensor,
|
|
568
|
+
value_spatial_shapes_list: list[tuple],
|
|
569
|
+
level_start_index: Tensor,
|
|
570
|
+
sampling_locations: Tensor,
|
|
571
|
+
attention_weights: Tensor,
|
|
572
|
+
im2col_step: int,
|
|
573
|
+
):
|
|
574
|
+
batch_size, _, num_heads, hidden_dim = value.shape
|
|
575
|
+
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
|
|
576
|
+
value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
|
|
577
|
+
sampling_grids = 2 * sampling_locations - 1
|
|
578
|
+
sampling_value_list = []
|
|
579
|
+
for level_id, (height, width) in enumerate(value_spatial_shapes_list):
|
|
580
|
+
# batch_size, height*width, num_heads, hidden_dim
|
|
581
|
+
# -> batch_size, height*width, num_heads*hidden_dim
|
|
582
|
+
# -> batch_size, num_heads*hidden_dim, height*width
|
|
583
|
+
# -> batch_size*num_heads, hidden_dim, height, width
|
|
584
|
+
value_l_ = (
|
|
585
|
+
value_list[level_id]
|
|
586
|
+
.flatten(2)
|
|
587
|
+
.transpose(1, 2)
|
|
588
|
+
.reshape(batch_size * num_heads, hidden_dim, height, width)
|
|
589
|
+
)
|
|
590
|
+
# batch_size, num_queries, num_heads, num_points, 2
|
|
591
|
+
# -> batch_size, num_heads, num_queries, num_points, 2
|
|
592
|
+
# -> batch_size*num_heads, num_queries, num_points, 2
|
|
593
|
+
sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
|
|
594
|
+
# batch_size*num_heads, hidden_dim, num_queries, num_points
|
|
595
|
+
sampling_value_l_ = nn.functional.grid_sample(
|
|
596
|
+
value_l_,
|
|
597
|
+
sampling_grid_l_,
|
|
598
|
+
mode="bilinear",
|
|
599
|
+
padding_mode="zeros",
|
|
600
|
+
align_corners=False,
|
|
601
|
+
)
|
|
602
|
+
sampling_value_list.append(sampling_value_l_)
|
|
603
|
+
# (batch_size, num_queries, num_heads, num_levels, num_points)
|
|
604
|
+
# -> (batch_size, num_heads, num_queries, num_levels, num_points)
|
|
605
|
+
# -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
|
|
606
|
+
attention_weights = attention_weights.transpose(1, 2).reshape(
|
|
607
|
+
batch_size * num_heads, 1, num_queries, num_levels * num_points
|
|
608
|
+
)
|
|
609
|
+
output = (
|
|
610
|
+
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
|
|
611
|
+
.sum(-1)
|
|
612
|
+
.view(batch_size, num_heads * hidden_dim, num_queries)
|
|
613
|
+
)
|
|
614
|
+
return output.transpose(1, 2).contiguous()
|
|
615
|
+
|
|
616
|
+
|
|
662
617
|
class RTDetrMultiscaleDeformableAttention(nn.Module):
|
|
663
618
|
"""
|
|
664
619
|
Multiscale deformable attention as proposed in Deformable DETR.
|
|
@@ -696,9 +651,6 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
|
|
|
696
651
|
|
|
697
652
|
self.disable_custom_kernels = config.disable_custom_kernels
|
|
698
653
|
|
|
699
|
-
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Tensor | None):
|
|
700
|
-
return tensor if position_embeddings is None else tensor + position_embeddings
|
|
701
|
-
|
|
702
654
|
def forward(
|
|
703
655
|
self,
|
|
704
656
|
hidden_states: torch.Tensor,
|
|
@@ -710,19 +662,19 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
|
|
|
710
662
|
spatial_shapes=None,
|
|
711
663
|
spatial_shapes_list=None,
|
|
712
664
|
level_start_index=None,
|
|
713
|
-
|
|
714
|
-
):
|
|
665
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
666
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
715
667
|
# add position embeddings to the hidden states before projecting to queries and keys
|
|
716
668
|
if position_embeddings is not None:
|
|
717
|
-
hidden_states =
|
|
669
|
+
hidden_states = hidden_states + position_embeddings
|
|
718
670
|
|
|
719
671
|
batch_size, num_queries, _ = hidden_states.shape
|
|
720
672
|
batch_size, sequence_length, _ = encoder_hidden_states.shape
|
|
721
673
|
total_elements = sum(height * width for height, width in spatial_shapes_list)
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
674
|
+
torch_compilable_check(
|
|
675
|
+
total_elements == sequence_length,
|
|
676
|
+
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
|
|
677
|
+
)
|
|
726
678
|
|
|
727
679
|
value = self.value_proj(encoder_hidden_states)
|
|
728
680
|
if attention_mask is not None:
|
|
@@ -769,235 +721,218 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
|
|
|
769
721
|
return output, attention_weights
|
|
770
722
|
|
|
771
723
|
|
|
772
|
-
class
|
|
773
|
-
|
|
774
|
-
Multi-headed attention from 'Attention Is All You Need' paper.
|
|
775
|
-
|
|
776
|
-
Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
|
|
777
|
-
"""
|
|
778
|
-
|
|
779
|
-
def __init__(
|
|
780
|
-
self,
|
|
781
|
-
embed_dim: int,
|
|
782
|
-
num_heads: int,
|
|
783
|
-
dropout: float = 0.0,
|
|
784
|
-
bias: bool = True,
|
|
785
|
-
):
|
|
724
|
+
class RTDetrDecoderLayer(nn.Module):
|
|
725
|
+
def __init__(self, config: RTDetrConfig):
|
|
786
726
|
super().__init__()
|
|
787
|
-
self.
|
|
788
|
-
self.num_heads = num_heads
|
|
789
|
-
self.dropout = dropout
|
|
790
|
-
self.head_dim = embed_dim // num_heads
|
|
791
|
-
if self.head_dim * num_heads != self.embed_dim:
|
|
792
|
-
raise ValueError(
|
|
793
|
-
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
|
794
|
-
f" {num_heads})."
|
|
795
|
-
)
|
|
796
|
-
self.scaling = self.head_dim**-0.5
|
|
797
|
-
|
|
798
|
-
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
|
799
|
-
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
|
800
|
-
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
|
801
|
-
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
|
727
|
+
self.hidden_size = config.d_model
|
|
802
728
|
|
|
803
|
-
|
|
804
|
-
|
|
729
|
+
# self-attention
|
|
730
|
+
self.self_attn = RTDetrSelfAttention(
|
|
731
|
+
config=config,
|
|
732
|
+
hidden_size=self.hidden_size,
|
|
733
|
+
num_attention_heads=config.decoder_attention_heads,
|
|
734
|
+
dropout=config.attention_dropout,
|
|
735
|
+
)
|
|
736
|
+
self.dropout = config.dropout
|
|
805
737
|
|
|
806
|
-
|
|
807
|
-
|
|
738
|
+
self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
|
739
|
+
# cross-attention
|
|
740
|
+
self.encoder_attn = RTDetrMultiscaleDeformableAttention(
|
|
741
|
+
config,
|
|
742
|
+
num_heads=config.decoder_attention_heads,
|
|
743
|
+
n_points=config.decoder_n_points,
|
|
744
|
+
)
|
|
745
|
+
self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
|
746
|
+
# feedforward neural networks
|
|
747
|
+
self.mlp = RTDetrMLP(config, self.hidden_size, config.decoder_ffn_dim, config.decoder_activation_function)
|
|
748
|
+
self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
|
808
749
|
|
|
809
750
|
def forward(
|
|
810
751
|
self,
|
|
811
752
|
hidden_states: torch.Tensor,
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
753
|
+
object_queries_position_embeddings: torch.Tensor | None = None,
|
|
754
|
+
reference_points=None,
|
|
755
|
+
spatial_shapes=None,
|
|
756
|
+
spatial_shapes_list=None,
|
|
757
|
+
level_start_index=None,
|
|
758
|
+
encoder_hidden_states: torch.Tensor | None = None,
|
|
759
|
+
encoder_attention_mask: torch.Tensor | None = None,
|
|
760
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
761
|
+
) -> torch.Tensor:
|
|
762
|
+
"""
|
|
763
|
+
Args:
|
|
764
|
+
hidden_states (`torch.FloatTensor`):
|
|
765
|
+
Input to the layer of shape `(batch, seq_len, hidden_size)`.
|
|
766
|
+
object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
|
|
767
|
+
Position embeddings for the object query slots. These are added to both queries and keys
|
|
768
|
+
in the self-attention layer (not values).
|
|
769
|
+
reference_points (`torch.FloatTensor`, *optional*):
|
|
770
|
+
Reference points.
|
|
771
|
+
spatial_shapes (`torch.LongTensor`, *optional*):
|
|
772
|
+
Spatial shapes.
|
|
773
|
+
level_start_index (`torch.LongTensor`, *optional*):
|
|
774
|
+
Level start index.
|
|
775
|
+
encoder_hidden_states (`torch.FloatTensor`):
|
|
776
|
+
cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
|
|
777
|
+
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
|
|
778
|
+
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
|
779
|
+
values.
|
|
780
|
+
"""
|
|
781
|
+
residual = hidden_states
|
|
817
782
|
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
783
|
+
# Self Attention
|
|
784
|
+
hidden_states, _ = self.self_attn(
|
|
785
|
+
hidden_states=hidden_states,
|
|
786
|
+
attention_mask=encoder_attention_mask,
|
|
787
|
+
position_embeddings=object_queries_position_embeddings,
|
|
788
|
+
**kwargs,
|
|
789
|
+
)
|
|
823
790
|
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size)
|
|
791
|
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
792
|
+
hidden_states = residual + hidden_states
|
|
793
|
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
|
828
794
|
|
|
829
|
-
|
|
830
|
-
query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape)
|
|
831
|
-
key_states = key_states.view(*proj_shape)
|
|
832
|
-
value_states = value_states.view(*proj_shape)
|
|
795
|
+
residual = hidden_states
|
|
833
796
|
|
|
834
|
-
|
|
797
|
+
# Cross-Attention
|
|
798
|
+
hidden_states, _ = self.encoder_attn(
|
|
799
|
+
hidden_states=hidden_states,
|
|
800
|
+
encoder_hidden_states=encoder_hidden_states,
|
|
801
|
+
position_embeddings=object_queries_position_embeddings,
|
|
802
|
+
reference_points=reference_points,
|
|
803
|
+
spatial_shapes=spatial_shapes,
|
|
804
|
+
spatial_shapes_list=spatial_shapes_list,
|
|
805
|
+
level_start_index=level_start_index,
|
|
806
|
+
**kwargs,
|
|
807
|
+
)
|
|
835
808
|
|
|
836
|
-
|
|
809
|
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
810
|
+
hidden_states = residual + hidden_states
|
|
837
811
|
|
|
838
|
-
|
|
839
|
-
raise ValueError(
|
|
840
|
-
f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
|
|
841
|
-
f" {attn_weights.size()}"
|
|
842
|
-
)
|
|
812
|
+
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
|
843
813
|
|
|
844
|
-
#
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
814
|
+
# Fully Connected
|
|
815
|
+
residual = hidden_states
|
|
816
|
+
hidden_states = self.mlp(hidden_states)
|
|
817
|
+
hidden_states = residual + hidden_states
|
|
818
|
+
hidden_states = self.final_layer_norm(hidden_states)
|
|
819
|
+
|
|
820
|
+
return hidden_states
|
|
848
821
|
|
|
849
|
-
if attention_mask is not None:
|
|
850
|
-
if attention_mask.size() != (batch_size, 1, target_len, source_len):
|
|
851
|
-
raise ValueError(
|
|
852
|
-
f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
|
|
853
|
-
f" {attention_mask.size()}"
|
|
854
|
-
)
|
|
855
|
-
if attention_mask.dtype == torch.bool:
|
|
856
|
-
attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
|
|
857
|
-
attention_mask, -torch.inf
|
|
858
|
-
)
|
|
859
|
-
attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
|
|
860
|
-
attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
|
|
861
|
-
|
|
862
|
-
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
|
863
|
-
|
|
864
|
-
if output_attentions:
|
|
865
|
-
# this operation is a bit awkward, but it's required to
|
|
866
|
-
# make sure that attn_weights keeps its gradient.
|
|
867
|
-
# In order to do so, attn_weights have to reshaped
|
|
868
|
-
# twice and have to be reused in the following
|
|
869
|
-
attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
|
|
870
|
-
attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
|
|
871
|
-
else:
|
|
872
|
-
attn_weights_reshaped = None
|
|
873
822
|
|
|
874
|
-
|
|
823
|
+
class RTDetrSinePositionEmbedding(nn.Module):
|
|
824
|
+
"""
|
|
825
|
+
2D sinusoidal position embedding used in RT-DETR hybrid encoder.
|
|
826
|
+
"""
|
|
875
827
|
|
|
876
|
-
|
|
828
|
+
def __init__(self, embed_dim: int = 256, temperature: int = 10000):
|
|
829
|
+
super().__init__()
|
|
830
|
+
self.embed_dim = embed_dim
|
|
831
|
+
self.temperature = temperature
|
|
877
832
|
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
833
|
+
@compile_compatible_method_lru_cache(maxsize=32)
|
|
834
|
+
def forward(
|
|
835
|
+
self,
|
|
836
|
+
width: int,
|
|
837
|
+
height: int,
|
|
838
|
+
device: torch.device | str,
|
|
839
|
+
dtype: torch.dtype,
|
|
840
|
+
) -> torch.Tensor:
|
|
841
|
+
"""
|
|
842
|
+
Generate 2D sinusoidal position embeddings.
|
|
883
843
|
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
844
|
+
Returns:
|
|
845
|
+
Position embeddings of shape (1, height*width, embed_dim)
|
|
846
|
+
"""
|
|
847
|
+
grid_w = torch.arange(torch_int(width), device=device).to(dtype)
|
|
848
|
+
grid_h = torch.arange(torch_int(height), device=device).to(dtype)
|
|
849
|
+
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
|
850
|
+
if self.embed_dim % 4 != 0:
|
|
851
|
+
raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
|
|
852
|
+
pos_dim = self.embed_dim // 4
|
|
853
|
+
omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
|
|
854
|
+
omega = 1.0 / (self.temperature**omega)
|
|
887
855
|
|
|
888
|
-
|
|
856
|
+
out_w = grid_w.flatten()[..., None] @ omega[None]
|
|
857
|
+
out_h = grid_h.flatten()[..., None] @ omega[None]
|
|
889
858
|
|
|
890
|
-
return
|
|
859
|
+
return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
|
|
891
860
|
|
|
892
861
|
|
|
893
|
-
class
|
|
862
|
+
class RTDetrAIFILayer(nn.Module):
|
|
863
|
+
"""
|
|
864
|
+
AIFI (Attention-based Intra-scale Feature Interaction) layer used in RT-DETR hybrid encoder.
|
|
865
|
+
"""
|
|
866
|
+
|
|
894
867
|
def __init__(self, config: RTDetrConfig):
|
|
895
868
|
super().__init__()
|
|
896
|
-
|
|
897
|
-
self.
|
|
898
|
-
|
|
899
|
-
num_heads=config.decoder_attention_heads,
|
|
900
|
-
dropout=config.attention_dropout,
|
|
901
|
-
)
|
|
902
|
-
self.dropout = config.dropout
|
|
903
|
-
self.activation_fn = ACT2FN[config.decoder_activation_function]
|
|
904
|
-
self.activation_dropout = config.activation_dropout
|
|
869
|
+
self.config = config
|
|
870
|
+
self.encoder_hidden_dim = config.encoder_hidden_dim
|
|
871
|
+
self.eval_size = config.eval_size
|
|
905
872
|
|
|
906
|
-
self.
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
config,
|
|
910
|
-
num_heads=config.decoder_attention_heads,
|
|
911
|
-
n_points=config.decoder_n_points,
|
|
873
|
+
self.position_embedding = RTDetrSinePositionEmbedding(
|
|
874
|
+
embed_dim=self.encoder_hidden_dim,
|
|
875
|
+
temperature=config.positional_encoding_temperature,
|
|
912
876
|
)
|
|
913
|
-
self.
|
|
914
|
-
# feedforward neural networks
|
|
915
|
-
self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
|
|
916
|
-
self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
|
|
917
|
-
self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
|
|
877
|
+
self.layers = nn.ModuleList([RTDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
|
|
918
878
|
|
|
919
879
|
def forward(
|
|
920
|
-
self,
|
|
921
|
-
hidden_states: torch.Tensor,
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
spatial_shapes=None,
|
|
925
|
-
spatial_shapes_list=None,
|
|
926
|
-
level_start_index=None,
|
|
927
|
-
encoder_hidden_states: torch.Tensor | None = None,
|
|
928
|
-
encoder_attention_mask: torch.Tensor | None = None,
|
|
929
|
-
output_attentions: bool | None = False,
|
|
930
|
-
):
|
|
880
|
+
self,
|
|
881
|
+
hidden_states: torch.Tensor,
|
|
882
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
883
|
+
) -> torch.Tensor:
|
|
931
884
|
"""
|
|
932
885
|
Args:
|
|
933
|
-
hidden_states (`torch.FloatTensor`):
|
|
934
|
-
|
|
935
|
-
position_embeddings (`torch.FloatTensor`, *optional*):
|
|
936
|
-
Position embeddings that are added to the queries and keys in the self-attention layer.
|
|
937
|
-
reference_points (`torch.FloatTensor`, *optional*):
|
|
938
|
-
Reference points.
|
|
939
|
-
spatial_shapes (`torch.LongTensor`, *optional*):
|
|
940
|
-
Spatial shapes.
|
|
941
|
-
level_start_index (`torch.LongTensor`, *optional*):
|
|
942
|
-
Level start index.
|
|
943
|
-
encoder_hidden_states (`torch.FloatTensor`):
|
|
944
|
-
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
|
945
|
-
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
|
|
946
|
-
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
|
947
|
-
values.
|
|
948
|
-
output_attentions (`bool`, *optional*):
|
|
949
|
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
950
|
-
returned tensors for more detail.
|
|
886
|
+
hidden_states (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
|
|
887
|
+
Feature map to process.
|
|
951
888
|
"""
|
|
952
|
-
|
|
889
|
+
batch_size = hidden_states.shape[0]
|
|
890
|
+
height, width = hidden_states.shape[2:]
|
|
953
891
|
|
|
954
|
-
|
|
955
|
-
hidden_states, self_attn_weights = self.self_attn(
|
|
956
|
-
hidden_states=hidden_states,
|
|
957
|
-
attention_mask=encoder_attention_mask,
|
|
958
|
-
position_embeddings=position_embeddings,
|
|
959
|
-
output_attentions=output_attentions,
|
|
960
|
-
)
|
|
892
|
+
hidden_states = hidden_states.flatten(2).permute(0, 2, 1)
|
|
961
893
|
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
894
|
+
if self.training or self.eval_size is None:
|
|
895
|
+
pos_embed = self.position_embedding(
|
|
896
|
+
width=width,
|
|
897
|
+
height=height,
|
|
898
|
+
device=hidden_states.device,
|
|
899
|
+
dtype=hidden_states.dtype,
|
|
900
|
+
)
|
|
901
|
+
else:
|
|
902
|
+
pos_embed = None
|
|
965
903
|
|
|
966
|
-
|
|
904
|
+
for layer in self.layers:
|
|
905
|
+
hidden_states = layer(
|
|
906
|
+
hidden_states,
|
|
907
|
+
attention_mask=None,
|
|
908
|
+
spatial_position_embeddings=pos_embed,
|
|
909
|
+
**kwargs,
|
|
910
|
+
)
|
|
967
911
|
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
hidden_states, cross_attn_weights = self.encoder_attn(
|
|
971
|
-
hidden_states=hidden_states,
|
|
972
|
-
encoder_hidden_states=encoder_hidden_states,
|
|
973
|
-
position_embeddings=position_embeddings,
|
|
974
|
-
reference_points=reference_points,
|
|
975
|
-
spatial_shapes=spatial_shapes,
|
|
976
|
-
spatial_shapes_list=spatial_shapes_list,
|
|
977
|
-
level_start_index=level_start_index,
|
|
978
|
-
output_attentions=output_attentions,
|
|
912
|
+
hidden_states = (
|
|
913
|
+
hidden_states.permute(0, 2, 1).reshape(batch_size, self.encoder_hidden_dim, height, width).contiguous()
|
|
979
914
|
)
|
|
980
915
|
|
|
981
|
-
|
|
982
|
-
hidden_states = second_residual + hidden_states
|
|
916
|
+
return hidden_states
|
|
983
917
|
|
|
984
|
-
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
|
985
918
|
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
hidden_states = self.fc2(hidden_states)
|
|
991
|
-
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
992
|
-
hidden_states = residual + hidden_states
|
|
993
|
-
hidden_states = self.final_layer_norm(hidden_states)
|
|
919
|
+
class RTDetrMLPPredictionHead(nn.Module):
|
|
920
|
+
"""
|
|
921
|
+
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
|
|
922
|
+
height and width of a bounding box w.r.t. an image.
|
|
994
923
|
|
|
995
|
-
|
|
924
|
+
"""
|
|
996
925
|
|
|
997
|
-
|
|
998
|
-
|
|
926
|
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
|
927
|
+
super().__init__()
|
|
928
|
+
self.num_layers = num_layers
|
|
929
|
+
h = [hidden_dim] * (num_layers - 1)
|
|
930
|
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
|
999
931
|
|
|
1000
|
-
|
|
932
|
+
def forward(self, x):
|
|
933
|
+
for i, layer in enumerate(self.layers):
|
|
934
|
+
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
|
935
|
+
return x
|
|
1001
936
|
|
|
1002
937
|
|
|
1003
938
|
@auto_docstring
|
|
@@ -1007,6 +942,10 @@ class RTDetrPreTrainedModel(PreTrainedModel):
|
|
|
1007
942
|
main_input_name = "pixel_values"
|
|
1008
943
|
input_modalities = ("image",)
|
|
1009
944
|
_no_split_modules = [r"RTDetrHybridEncoder", r"RTDetrDecoderLayer"]
|
|
945
|
+
_supports_sdpa = True
|
|
946
|
+
_supports_flash_attn = True
|
|
947
|
+
_supports_attention_backend = True
|
|
948
|
+
_supports_flex_attn = True
|
|
1010
949
|
|
|
1011
950
|
@torch.no_grad()
|
|
1012
951
|
def _init_weights(self, module):
|
|
@@ -1072,35 +1011,23 @@ class RTDetrPreTrainedModel(PreTrainedModel):
|
|
|
1072
1011
|
init.xavier_uniform_(module.denoising_class_embed.weight)
|
|
1073
1012
|
|
|
1074
1013
|
|
|
1075
|
-
class
|
|
1076
|
-
def __init__(self, config: RTDetrConfig):
|
|
1077
|
-
super().__init__()
|
|
1078
|
-
|
|
1079
|
-
self.layers = nn.ModuleList([RTDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
|
|
1080
|
-
|
|
1081
|
-
def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor:
|
|
1082
|
-
hidden_states = src
|
|
1083
|
-
for layer in self.layers:
|
|
1084
|
-
hidden_states = layer(
|
|
1085
|
-
hidden_states,
|
|
1086
|
-
attention_mask=src_mask,
|
|
1087
|
-
position_embeddings=pos_embed,
|
|
1088
|
-
output_attentions=output_attentions,
|
|
1089
|
-
)
|
|
1090
|
-
return hidden_states
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
class RTDetrHybridEncoder(nn.Module):
|
|
1014
|
+
class RTDetrHybridEncoder(RTDetrPreTrainedModel):
|
|
1094
1015
|
"""
|
|
1095
|
-
|
|
1096
|
-
(FPN) and a bottom-up Path Aggregation Network (PAN).
|
|
1016
|
+
Hybrid encoder consisting of AIFI (Attention-based Intra-scale Feature Interaction) layers,
|
|
1017
|
+
a top-down Feature Pyramid Network (FPN) and a bottom-up Path Aggregation Network (PAN).
|
|
1018
|
+
More details on the paper: https://huggingface.co/papers/2304.08069
|
|
1097
1019
|
|
|
1098
1020
|
Args:
|
|
1099
1021
|
config: RTDetrConfig
|
|
1100
1022
|
"""
|
|
1101
1023
|
|
|
1024
|
+
_can_record_outputs = {
|
|
1025
|
+
"hidden_states": RTDetrAIFILayer,
|
|
1026
|
+
"attentions": RTDetrSelfAttention,
|
|
1027
|
+
}
|
|
1028
|
+
|
|
1102
1029
|
def __init__(self, config: RTDetrConfig):
|
|
1103
|
-
super().__init__()
|
|
1030
|
+
super().__init__(config)
|
|
1104
1031
|
self.config = config
|
|
1105
1032
|
self.in_channels = config.encoder_in_channels
|
|
1106
1033
|
self.feat_strides = config.feat_strides
|
|
@@ -1112,10 +1039,9 @@ class RTDetrHybridEncoder(nn.Module):
|
|
|
1112
1039
|
self.out_strides = self.feat_strides
|
|
1113
1040
|
self.num_fpn_stages = len(self.in_channels) - 1
|
|
1114
1041
|
self.num_pan_stages = len(self.in_channels) - 1
|
|
1115
|
-
activation = config.activation_function
|
|
1116
1042
|
|
|
1117
|
-
#
|
|
1118
|
-
self.
|
|
1043
|
+
# AIFI (Attention-based Intra-scale Feature Interaction) layers
|
|
1044
|
+
self.aifi = nn.ModuleList([RTDetrAIFILayer(config) for _ in range(len(self.encode_proj_layers))])
|
|
1119
1045
|
|
|
1120
1046
|
# top-down FPN
|
|
1121
1047
|
self.lateral_convs = nn.ModuleList()
|
|
@@ -1127,7 +1053,7 @@ class RTDetrHybridEncoder(nn.Module):
|
|
|
1127
1053
|
out_channels=self.encoder_hidden_dim,
|
|
1128
1054
|
kernel_size=1,
|
|
1129
1055
|
stride=1,
|
|
1130
|
-
activation=
|
|
1056
|
+
activation=config.activation_function,
|
|
1131
1057
|
)
|
|
1132
1058
|
fpn_block = RTDetrCSPRepLayer(config)
|
|
1133
1059
|
self.lateral_convs.append(lateral_conv)
|
|
@@ -1143,118 +1069,36 @@ class RTDetrHybridEncoder(nn.Module):
|
|
|
1143
1069
|
out_channels=self.encoder_hidden_dim,
|
|
1144
1070
|
kernel_size=3,
|
|
1145
1071
|
stride=2,
|
|
1146
|
-
activation=
|
|
1072
|
+
activation=config.activation_function,
|
|
1147
1073
|
)
|
|
1148
1074
|
pan_block = RTDetrCSPRepLayer(config)
|
|
1149
1075
|
self.downsample_convs.append(downsample_conv)
|
|
1150
1076
|
self.pan_blocks.append(pan_block)
|
|
1151
1077
|
|
|
1152
|
-
|
|
1153
|
-
def build_2d_sincos_position_embedding(
|
|
1154
|
-
width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
|
|
1155
|
-
):
|
|
1156
|
-
grid_w = torch.arange(torch_int(width), device=device).to(dtype)
|
|
1157
|
-
grid_h = torch.arange(torch_int(height), device=device).to(dtype)
|
|
1158
|
-
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
|
1159
|
-
if embed_dim % 4 != 0:
|
|
1160
|
-
raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
|
|
1161
|
-
pos_dim = embed_dim // 4
|
|
1162
|
-
omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
|
|
1163
|
-
omega = 1.0 / (temperature**omega)
|
|
1164
|
-
|
|
1165
|
-
out_w = grid_w.flatten()[..., None] @ omega[None]
|
|
1166
|
-
out_h = grid_h.flatten()[..., None] @ omega[None]
|
|
1167
|
-
|
|
1168
|
-
return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
|
|
1078
|
+
self.post_init()
|
|
1169
1079
|
|
|
1080
|
+
@check_model_inputs(tie_last_hidden_states=False)
|
|
1170
1081
|
def forward(
|
|
1171
1082
|
self,
|
|
1172
1083
|
inputs_embeds=None,
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
spatial_shapes=None,
|
|
1176
|
-
level_start_index=None,
|
|
1177
|
-
valid_ratios=None,
|
|
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
-            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
-                - 1 for pixel features that are real (i.e. **not masked**),
-                - 0 for pixel features that are padding (i.e. **masked**).
-                [What are attention masks?](../glossary#attention-mask)
-            position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-                Position embeddings that are added to the queries and keys in each self-attention layer.
-            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
-                Spatial shapes of each feature map.
-            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
-                Starting index of each feature map.
-            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
-                Ratio of valid area in each feature level.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
         """
-
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        hidden_states = inputs_embeds
+        feature_maps = inputs_embeds

-
-        all_attentions = () if output_attentions else None
-
-        # encoder
+        # AIFI: Apply transformer encoder to specified feature levels
         if self.config.encoder_layers > 0:
             for i, enc_ind in enumerate(self.encode_proj_layers):
-
-                encoder_states = encoder_states + (hidden_states[enc_ind],)
-                height, width = hidden_states[enc_ind].shape[2:]
-                # flatten [batch, channel, height, width] to [batch, height*width, channel]
-                src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
-                if self.training or self.eval_size is None:
-                    pos_embed = self.build_2d_sincos_position_embedding(
-                        width,
-                        height,
-                        self.encoder_hidden_dim,
-                        self.positional_encoding_temperature,
-                        device=src_flatten.device,
-                        dtype=src_flatten.dtype,
-                    )
-                else:
-                    pos_embed = None
-
-                layer_outputs = self.encoder[i](
-                    src_flatten,
-                    pos_embed=pos_embed,
-                    output_attentions=output_attentions,
-                )
-                hidden_states[enc_ind] = (
-                    layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
-                )
-
-                if output_attentions:
-                    all_attentions = all_attentions + (layer_outputs[1],)
-
-                if output_hidden_states:
-                    encoder_states = encoder_states + (hidden_states[enc_ind],)
+                feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs)

         # top-down FPN
-        fpn_feature_maps = [
+        fpn_feature_maps = [feature_maps[-1]]
         for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
-            backbone_feature_map =
+            backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1]
             top_fpn_feature_map = fpn_feature_maps[-1]
             # apply lateral block
             top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
@@ -1277,20 +1121,29 @@ class RTDetrHybridEncoder(nn.Module):
             new_pan_feature_map = pan_block(fused_feature_map)
             pan_feature_maps.append(new_pan_feature_map)

-
-
-
-
-
+        return BaseModelOutput(last_hidden_state=pan_feature_maps)
+
+
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0, max=1)
+    x1 = x.clamp(min=eps)
+    x2 = (1 - x).clamp(min=eps)
+    return torch.log(x1 / x2)
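The hunk above promotes `inverse_sigmoid` to a module-level helper: it is the clamped logit function used to map normalized box coordinates from `[0, 1]` back to unconstrained logits before iterative refinement. A minimal round-trip sketch (a standalone copy for illustration, not imported from the library):

```python
import torch


def inverse_sigmoid(x, eps=1e-5):
    # Clamp into [0, 1], keep numerator and denominator away from zero, then take the log-ratio.
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


boxes = torch.tensor([0.25, 0.50, 0.90])
logits = inverse_sigmoid(boxes)
# Away from the clamped boundaries, sigmoid undoes inverse_sigmoid.
print(torch.allclose(torch.sigmoid(logits), boxes, atol=1e-4))  # True
```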


 class RTDetrDecoder(RTDetrPreTrainedModel):
+    _can_record_outputs = {
+        "hidden_states": RTDetrDecoderLayer,
+        "attentions": RTDetrSelfAttention,
+        "cross_attentions": RTDetrMultiscaleDeformableAttention,
+    }
+
     def __init__(self, config: RTDetrConfig):
         super().__init__(config)

         self.dropout = config.dropout
         self.layers = nn.ModuleList([RTDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.query_pos_head = RTDetrMLPPredictionHead(
+        self.query_pos_head = RTDetrMLPPredictionHead(4, 2 * config.d_model, config.d_model, num_layers=2)

         # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
         self.bbox_embed = None
@@ -1299,21 +1152,17 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
-        position_embeddings=None,
         reference_points=None,
         spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
-
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ):
         r"""
         Args:
@@ -1327,39 +1176,17 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
                 in `[0, 1]`:
                 - 1 for pixels that are real (i.e. **not masked**),
                 - 0 for pixels that are padding (i.e. **masked**).
-            position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
-                Position embeddings that are added to the queries and keys in each self-attention layer.
             reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*):
                 Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
             spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
                 Spatial shapes of the feature maps.
             level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
                 Indexes for the start of each feature level. In range `[0, sequence_length]`.
-            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
-                Ratio of valid area in each feature level.
-
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         if inputs_embeds is not None:
             hidden_states = inputs_embeds

         # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
         intermediate = ()
         intermediate_reference_points = ()
         intermediate_logits = ()
@@ -1369,25 +1196,20 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
         # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L252
         for idx, decoder_layer in enumerate(self.layers):
             reference_points_input = reference_points.unsqueeze(2)
-
-
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
+            object_queries_position_embeddings = self.query_pos_head(reference_points)

-
+            hidden_states = decoder_layer(
                 hidden_states,
-
+                object_queries_position_embeddings=object_queries_position_embeddings,
                 encoder_hidden_states=encoder_hidden_states,
                 reference_points=reference_points_input,
                 spatial_shapes=spatial_shapes,
                 spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 encoder_attention_mask=encoder_attention_mask,
-
+                **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
             # hack implementation for iterative bounding box refinement
             if self.bbox_embed is not None:
                 predicted_corners = self.bbox_embed[idx](hidden_states)
@@ -1403,68 +1225,141 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
                 logits = self.class_embed[idx](hidden_states)
                 intermediate_logits += (logits,)

-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-                if encoder_hidden_states is not None:
-                    all_cross_attentions += (layer_outputs[2],)
-
         # Keep batch_size as first dimension
         intermediate = torch.stack(intermediate, dim=1)
         intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
         if self.class_embed is not None:
             intermediate_logits = torch.stack(intermediate_logits, dim=1)

-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        if not return_dict:
-            return tuple(
-                v
-                for v in [
-                    hidden_states,
-                    intermediate,
-                    intermediate_logits,
-                    intermediate_reference_points,
-                    all_hidden_states,
-                    all_self_attns,
-                    all_cross_attentions,
-                ]
-                if v is not None
-            )
         return RTDetrDecoderOutput(
             last_hidden_state=hidden_states,
             intermediate_hidden_states=intermediate,
             intermediate_logits=intermediate_logits,
             intermediate_reference_points=intermediate_reference_points,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-            cross_attentions=all_cross_attentions,
         )


-
-
+def get_contrastive_denoising_training_group(
+    targets,
+    num_classes,
+    num_queries,
+    class_embed,
+    num_denoising_queries=100,
+    label_noise_ratio=0.5,
+    box_noise_scale=1.0,
+):
     """
-
-    height and width of a bounding box w.r.t. an image.
-
-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-    Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_paddle/ppdet/modeling/transformers/utils.py#L453
+    Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.

+    Args:
+        targets (`list[dict]`):
+            The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
+        num_classes (`int`):
+            Total number of classes in the dataset.
+        num_queries (`int`):
+            Number of query slots in the transformer.
+        class_embed (`callable`):
+            A function or a model layer to embed class labels.
+        num_denoising_queries (`int`, *optional*, defaults to 100):
+            Number of denoising queries.
+        label_noise_ratio (`float`, *optional*, defaults to 0.5):
+            Ratio of noise applied to labels.
+        box_noise_scale (`float`, *optional*, defaults to 1.0):
+            Scale of noise applied to bounding boxes.
+    Returns:
+        `tuple` comprising various elements:
+        - **input_query_class** (`torch.FloatTensor`) --
+          Class queries with applied label noise.
+        - **input_query_bbox** (`torch.FloatTensor`) --
+          Bounding box queries with applied box noise.
+        - **attn_mask** (`torch.FloatTensor`) --
+          Attention mask for separating denoising and reconstruction queries.
+        - **denoising_meta_values** (`dict`) --
+          Metadata including denoising positive indices, number of groups, and split sizes.
     """

-
-    self.num_layers = num_layers
-    h = [d_model] * (num_layers - 1)
-    self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+    if num_denoising_queries <= 0:
+        return None, None, None, None

-
-
-
-
+    num_ground_truths = [len(t["class_labels"]) for t in targets]
+    device = targets[0]["class_labels"].device
+
+    max_gt_num = max(num_ground_truths)
+    if max_gt_num == 0:
+        return None, None, None, None
+
+    num_groups_denoising_queries = num_denoising_queries // max_gt_num
+    num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
+    # pad gt to max_num of a batch
+    batch_size = len(num_ground_truths)
+
+    input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
+    input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
+    pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)
+
+    for i in range(batch_size):
+        num_gt = num_ground_truths[i]
+        if num_gt > 0:
+            input_query_class[i, :num_gt] = targets[i]["class_labels"]
+            input_query_bbox[i, :num_gt] = targets[i]["boxes"]
+            pad_gt_mask[i, :num_gt] = 1
+    # each group has positive and negative queries.
+    input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
+    input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
+    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
+    # positive and negative mask
+    negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
+    negative_gt_mask[:, max_gt_num:] = 1
+    negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
+    positive_gt_mask = 1 - negative_gt_mask
+    # contrastive denoising training positive index
+    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
+    denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
+    denoise_positive_idx = torch.split(
+        denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
+    )
+    # total denoising queries
+    num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
+
+    if label_noise_ratio > 0:
+        mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
+        # randomly put a new one here
+        new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
+        input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
+
+    if box_noise_scale > 0:
+        known_bbox = center_to_corners_format(input_query_bbox)
+        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
+        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
+        rand_part = torch.rand_like(input_query_bbox)
+        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
+        rand_part *= rand_sign
+        known_bbox += rand_part * diff
+        known_bbox.clip_(min=0.0, max=1.0)
+        input_query_bbox = corners_to_center_format(known_bbox)
+        input_query_bbox = inverse_sigmoid(input_query_bbox)
+
+    input_query_class = class_embed(input_query_class)
+
+    target_size = num_denoising_queries + num_queries
+    attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
+    # match query cannot see the reconstruction
+    attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
+
+    # reconstructions cannot see each other
+    for i in range(num_groups_denoising_queries):
+        idx_block_start = max_gt_num * 2 * i
+        idx_block_end = max_gt_num * 2 * (i + 1)
+        attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
+        attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf
+
+    denoising_meta_values = {
+        "dn_positive_idx": denoise_positive_idx,
+        "dn_num_group": num_groups_denoising_queries,
+        "dn_num_split": [num_denoising_queries, num_queries],
+    }
+
+    return input_query_class, input_query_bbox, attn_mask, denoising_meta_values


 @auto_docstring(
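The newly added `get_contrastive_denoising_training_group` is self-contained in the hunk above, so its call pattern can be sketched directly. A rough usage example with toy inputs; the import path, class count, and embedding size are assumptions for illustration, while the arguments and return values are exactly those added above:

```python
import torch
from torch import nn
from transformers.models.rt_detr.modeling_rt_detr import get_contrastive_denoising_training_group

num_classes, num_queries, d_model = 80, 300, 256
# One image with two ground-truth objects; boxes are (cx, cy, w, h) normalized to [0, 1].
targets = [{
    "class_labels": torch.tensor([3, 17]),
    "boxes": torch.tensor([[0.50, 0.50, 0.20, 0.30], [0.30, 0.60, 0.10, 0.10]]),
}]
# Extra embedding row plays the role of the padding / "no object" label used for padded slots.
class_embed = nn.Embedding(num_classes + 1, d_model)

query_class, query_bbox, attn_mask, meta = get_contrastive_denoising_training_group(
    targets, num_classes, num_queries, class_embed,
    num_denoising_queries=100, label_noise_ratio=0.5, box_noise_scale=1.0,
)
print(query_class.shape)     # embedded noisy class queries: (1, 200, 256)
print(query_bbox.shape)      # noisy boxes in logit space:   (1, 200, 4)
print(attn_mask.shape)       # (200 + 300, 200 + 300)
print(meta["dn_num_split"])  # [200, 300]
```

With two ground-truth boxes and 100 requested denoising queries, the helper builds 50 groups of one positive and one negative copy per object, which is where the 200 denoising slots in the printed shapes come from.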
@@ -1484,8 +1379,8 @@ class RTDetrModel(RTDetrPreTrainedModel):
         # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py#L212
         num_backbone_outs = len(intermediate_channel_sizes)
         encoder_input_proj_list = []
-        for
-            in_channels = intermediate_channel_sizes[
+        for i in range(num_backbone_outs):
+            in_channels = intermediate_channel_sizes[i]
             encoder_input_proj_list.append(
                 nn.Sequential(
                     nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
@@ -1513,7 +1408,7 @@ class RTDetrModel(RTDetrPreTrainedModel):
             nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
         )
         self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
-        self.enc_bbox_head = RTDetrMLPPredictionHead(config
+        self.enc_bbox_head = RTDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)

         # init encoder output anchors and valid_mask
         if config.anchor_image_size:
@@ -1523,8 +1418,8 @@ class RTDetrModel(RTDetrPreTrainedModel):
         # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
         num_backbone_outs = len(config.decoder_in_channels)
         decoder_input_proj_list = []
-        for
-            in_channels = config.decoder_in_channels[
+        for i in range(num_backbone_outs):
+            in_channels = config.decoder_in_channels[i]
             decoder_input_proj_list.append(
                 nn.Sequential(
                     nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
@@ -1584,26 +1479,20 @@ class RTDetrModel(RTDetrPreTrainedModel):
         return anchors, valid_mask

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
         pixel_mask: torch.LongTensor | None = None,
         encoder_outputs: torch.FloatTensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
-        decoder_inputs_embeds: torch.FloatTensor | None = None,
         labels: list[dict] | None = None,
-
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | RTDetrModelOutput:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
             can choose to directly pass a flattened representation of an image.
-        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
-            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
         labels (`list[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
             following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
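Both `RTDetrModel.forward` and `RTDetrForObjectDetection.forward` drop the explicit `output_attentions`/`output_hidden_states`/`return_dict` parameters in favor of `**kwargs: Unpack[TransformersKwargs]` plus the `@can_return_tuple` and `@check_model_inputs` decorators. A minimal call sketch, assuming those flags are still accepted by name through the kwargs plumbing and using the public `PekingU/rtdetr_r50vd` checkpoint for illustration:

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, RTDetrModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrModel.from_pretrained("PekingU/rtdetr_r50vd")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    # The flags that used to be explicit parameters now travel through **kwargs.
    outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

print(list(outputs.last_hidden_state.shape))  # [1, 300, 256], matching the docstring example below
```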
@@ -1631,53 +1520,46 @@ class RTDetrModel(RTDetrPreTrainedModel):
         >>> list(last_hidden_states.shape)
         [1, 300, 256]
         ```"""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+        if pixel_values is None and inputs_embeds is None:
+            raise ValueError("You have to specify either pixel_values or inputs_embeds")
+
+        if inputs_embeds is None:
+            batch_size, num_channels, height, width = pixel_values.shape
+            device = pixel_values.device
+            if pixel_mask is None:
+                pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+            features = self.backbone(pixel_values, pixel_mask)
+            proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+        else:
+            batch_size = inputs_embeds.shape[0]
+            device = inputs_embeds.device
+            proj_feats = inputs_embeds

         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 proj_feats,
-
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
+                **kwargs,
             )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
-        elif
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+        elif not isinstance(encoder_outputs, BaseModelOutput):
             encoder_outputs = BaseModelOutput(
                 last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if
-                attentions=encoder_outputs[2]
-                if len(encoder_outputs) > 2
-                else encoder_outputs[1]
-                if output_attentions
-                else None,
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
             )

         # Equivalent to def _get_encoder_input
         # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
         sources = []
-        for level, source in enumerate(encoder_outputs
+        for level, source in enumerate(encoder_outputs.last_hidden_state):
             sources.append(self.decoder_input_proj[level](source))

         # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
         if self.config.num_feature_levels > len(sources):
             _len_sources = len(sources)
-            sources.append(self.decoder_input_proj[_len_sources](encoder_outputs
+            sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state)[-1])
             for i in range(_len_sources + 1, self.config.num_feature_levels):
-                sources.append(self.decoder_input_proj[i](encoder_outputs[
+                sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1]))

         # Prepare encoder inputs (by flattening)
         source_flatten = []
@@ -1769,22 +1651,9 @@ class RTDetrModel(RTDetrPreTrainedModel):
             spatial_shapes=spatial_shapes,
             spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
-
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

-        if not return_dict:
-            enc_outputs = tuple(
-                value
-                for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits]
-                if value is not None
-            )
-            dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values])
-            tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs
-
-            return tuple_outputs
-
         return RTDetrModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
             intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
@@ -1826,7 +1695,7 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
             [torch.nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]
         )
         self.model.decoder.bbox_embed = nn.ModuleList(
-            [RTDetrMLPPredictionHead(config
+            [RTDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) for _ in range(num_pred)]
         )
         # if two-stage, the last class_embed and bbox_embed is for region proposal generation
         self.post_init()
@@ -1835,26 +1704,20 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
         return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
         pixel_mask: torch.LongTensor | None = None,
         encoder_outputs: torch.FloatTensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
-        decoder_inputs_embeds: torch.FloatTensor | None = None,
         labels: list[dict] | None = None,
-
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | RTDetrObjectDetectionOutput:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
             can choose to directly pass a flattened representation of an image.
-        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
-            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
         labels (`list[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
             following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
@@ -1907,40 +1770,29 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
         Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48]
         Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99]
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         outputs = self.model(
             pixel_values,
             pixel_mask=pixel_mask,
             encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
             labels=labels,
-
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

-        denoising_meta_values =
-            outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None
-        )
+        denoising_meta_values = outputs.denoising_meta_values if self.training else None

-        outputs_class = outputs.intermediate_logits
-        outputs_coord = outputs.intermediate_reference_points
-        predicted_corners = outputs.intermediate_predicted_corners
-        initial_reference_points = outputs.initial_reference_points
+        outputs_class = outputs.intermediate_logits
+        outputs_coord = outputs.intermediate_reference_points
+        predicted_corners = outputs.intermediate_predicted_corners
+        initial_reference_points = outputs.initial_reference_points

         logits = outputs_class[:, -1]
         pred_boxes = outputs_coord[:, -1]

         loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
         if labels is not None:
-            enc_topk_logits = outputs.enc_topk_logits
-            enc_topk_bboxes = outputs.enc_topk_bboxes
+            enc_topk_logits = outputs.enc_topk_logits
+            enc_topk_bboxes = outputs.enc_topk_bboxes
             loss, loss_dict, auxiliary_outputs = self.loss_function(
                 logits,
                 labels,
@@ -1957,13 +1809,6 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
                 **kwargs,
             )

-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs
-            else:
-                output = (logits, pred_boxes) + outputs
-            return ((loss, loss_dict) + output) if loss is not None else output
-
         return RTDetrObjectDetectionOutput(
             loss=loss,
             loss_dict=loss_dict,
@@ -1991,8 +1836,4 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
         )


-__all__ = [
-    "RTDetrForObjectDetection",
-    "RTDetrModel",
-    "RTDetrPreTrainedModel",
-]
+__all__ = ["RTDetrForObjectDetection", "RTDetrModel", "RTDetrPreTrainedModel"]