transformers 5.0.0rc3__py3-none-any.whl → 5.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,14 +1,32 @@
|
|
|
1
|
+
# Copyright 2024 Baidu Inc and The HuggingFace Inc. team.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import math
|
|
1
15
|
import pathlib
|
|
16
|
+
from dataclasses import dataclass
|
|
2
17
|
from typing import Optional
|
|
3
18
|
|
|
4
19
|
import torch
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
from
|
|
20
|
+
import torch.nn.functional as F
|
|
21
|
+
import torchvision.transforms.v2.functional as tvF
|
|
22
|
+
from torch import nn
|
|
8
23
|
|
|
24
|
+
from ... import initialization as init
|
|
25
|
+
from ...activations import ACT2CLS, ACT2FN
|
|
26
|
+
from ...backbone_utils import load_backbone
|
|
9
27
|
from ...image_processing_utils import BatchFeature
|
|
10
28
|
from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, get_max_height_width
|
|
11
|
-
from ...image_transforms import center_to_corners_format
|
|
29
|
+
from ...image_transforms import center_to_corners_format, corners_to_center_format
|
|
12
30
|
from ...image_utils import (
|
|
13
31
|
IMAGENET_DEFAULT_MEAN,
|
|
14
32
|
IMAGENET_DEFAULT_STD,
|
|
@@ -19,12 +37,25 @@ from ...image_utils import (
|
|
|
19
37
|
get_image_size,
|
|
20
38
|
validate_annotations,
|
|
21
39
|
)
|
|
40
|
+
from ...modeling_outputs import BaseModelOutput
|
|
41
|
+
from ...modeling_utils import PreTrainedModel
|
|
22
42
|
from ...processing_utils import Unpack
|
|
43
|
+
from ...pytorch_utils import compile_compatible_method_lru_cache
|
|
23
44
|
from ...utils import (
|
|
45
|
+
ModelOutput,
|
|
24
46
|
TensorType,
|
|
47
|
+
TransformersKwargs,
|
|
48
|
+
auto_docstring,
|
|
25
49
|
logging,
|
|
26
50
|
requires_backends,
|
|
51
|
+
torch_int,
|
|
27
52
|
)
|
|
53
|
+
from ...utils.generic import can_return_tuple, check_model_inputs
|
|
54
|
+
from ..conditional_detr.modeling_conditional_detr import inverse_sigmoid
|
|
55
|
+
from ..deformable_detr.modeling_deformable_detr import DeformableDetrMultiscaleDeformableAttention
|
|
56
|
+
from ..detr.image_processing_detr_fast import DetrImageProcessorFast
|
|
57
|
+
from ..detr.modeling_detr import DetrFrozenBatchNorm2d, DetrMLPPredictionHead, DetrSelfAttention, replace_batch_norm
|
|
58
|
+
from .configuration_rt_detr import RTDetrConfig
|
|
28
59
|
from .image_processing_rt_detr import RTDetrImageProcessorKwargs
|
|
29
60
|
|
|
30
61
|
|
|
@@ -144,7 +175,7 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast):
|
|
|
144
175
|
return_segmentation_masks: bool,
|
|
145
176
|
do_resize: bool,
|
|
146
177
|
size: SizeDict,
|
|
147
|
-
interpolation: Optional["
|
|
178
|
+
interpolation: Optional["tvF.InterpolationMode"],
|
|
148
179
|
do_rescale: bool,
|
|
149
180
|
rescale_factor: float,
|
|
150
181
|
do_normalize: bool,
|
|
@@ -324,4 +355,1475 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast):
|
|
|
324
355
|
raise NotImplementedError("Panoptic segmentation post-processing is not implemented for RT-DETR yet.")
|
|
325
356
|
|
|
326
357
|
|
|
327
|
-
|
|
358
|
+
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the RTDetrDecoder. On top of BaseModelOutputWithCrossAttentions, this class adds:
    - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
    - a stacked tensor of intermediate logits
    - a stacked tensor of intermediate reference points
    - a stacked tensor of intermediate predicted corners
    - the initial reference points
    """
)
class RTDetrDecoderOutput(ModelOutput):
    r"""
    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
        Stacked intermediate hidden states (output of each layer of the decoder).
    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
        Stacked intermediate logits (logits of each layer of the decoder).
    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate reference points (reference points of each layer of the decoder).
    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked initial reference points (initial reference points of each layer of the decoder).
    cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
        used to compute the weighted average in the cross-attention heads.
    """

    # NOTE: field order is significant for ModelOutput (it defines tuple/index order); do not reorder.
    last_hidden_state: torch.FloatTensor | None = None
    intermediate_hidden_states: torch.FloatTensor | None = None
    intermediate_logits: torch.FloatTensor | None = None
    intermediate_reference_points: torch.FloatTensor | None = None
    intermediate_predicted_corners: torch.FloatTensor | None = None
    initial_reference_points: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the RT-DETR encoder-decoder model.
    """
)
class RTDetrModelOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
        Hidden states emitted by the final decoder layer.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
        Hidden states of every decoder layer, stacked along the layer dimension.
    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
        Classification logits of every decoder layer, stacked along the layer dimension.
    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Reference points of every decoder layer, stacked along the layer dimension.
    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Predicted corners of every decoder layer, stacked along the layer dimension.
    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Reference points fed to the first decoder layer.
    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Initial reference points passed into the Transformer decoder.
    enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
        Scores of the predicted bounding boxes in the encoder stage; the top `config.two_stage_num_proposals`
        boxes are kept as region proposals. This is the output of the binary (foreground vs. background) box
        classification.
    enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`):
        Logits of the predicted bounding box coordinates in the encoder stage.
    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Scores of the predicted bounding boxes in the first stage; the top `config.two_stage_num_proposals`
        boxes are kept as region proposals. This is the output of the binary (foreground vs. background) box
        classification.
    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Logits of the predicted bounding box coordinates in the first stage.
    denoising_meta_values (`dict`):
        Additional dictionary holding the denoising-related values.
    """

    # Field order defines the ModelOutput tuple/index order and must stay as-is.
    last_hidden_state: torch.FloatTensor | None = None
    intermediate_hidden_states: torch.FloatTensor | None = None
    intermediate_logits: torch.FloatTensor | None = None
    intermediate_reference_points: torch.FloatTensor | None = None
    intermediate_predicted_corners: torch.FloatTensor | None = None
    initial_reference_points: torch.FloatTensor | None = None
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
    init_reference_points: torch.FloatTensor | None = None
    enc_topk_logits: torch.FloatTensor | None = None
    enc_topk_bboxes: torch.FloatTensor | None = None
    enc_outputs_class: torch.FloatTensor | None = None
    enc_outputs_coord_logits: torch.FloatTensor | None = None
    denoising_meta_values: dict | None = None
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`RTDetrForObjectDetection`].
    """
)
class RTDetrObjectDetectionOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~RTDetrImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized (absolute) bounding boxes.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
        Stacked intermediate hidden states (output of each layer of the decoder).
    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`):
        Stacked intermediate logits (logits of each layer of the decoder).
    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate reference points (reference points of each layer of the decoder).
    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked initial reference points (initial reference points of each layer of the decoder).
    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Initial reference points sent through the Transformer decoder.
    enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
        picked as region proposals in the encoder stage.
    enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Logits of predicted bounding boxes coordinates in the encoder.
    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
        picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
        foreground and background).
    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Logits of predicted bounding boxes coordinates in the first stage.
    denoising_meta_values (`dict`):
        Extra dictionary for the denoising related values.
    """

    # NOTE: field order is significant for ModelOutput (it defines tuple/index order); do not reorder.
    loss: torch.FloatTensor | None = None
    loss_dict: dict | None = None
    logits: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    auxiliary_outputs: list[dict] | None = None
    last_hidden_state: torch.FloatTensor | None = None
    intermediate_hidden_states: torch.FloatTensor | None = None
    intermediate_logits: torch.FloatTensor | None = None
    intermediate_reference_points: torch.FloatTensor | None = None
    intermediate_predicted_corners: torch.FloatTensor | None = None
    initial_reference_points: torch.FloatTensor | None = None
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
    # A single tensor (see docstring / RTDetrModelOutput), not a tuple of tensors.
    init_reference_points: torch.FloatTensor | None = None
    enc_topk_logits: torch.FloatTensor | None = None
    enc_topk_bboxes: torch.FloatTensor | None = None
    enc_outputs_class: torch.FloatTensor | None = None
    enc_outputs_coord_logits: torch.FloatTensor | None = None
    denoising_meta_values: dict | None = None
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
class RTDetrMLP(nn.Module):
    """
    Position-wise feed-forward network: `fc1` -> activation -> dropout (`config.activation_dropout`) -> `fc2`
    -> dropout (`config.dropout`). Output has the same last dimension (`hidden_size`) as the input.
    """

    def __init__(self, config: RTDetrConfig, hidden_size: int, intermediate_size: int, activation_function: str):
        super().__init__()
        # Expansion and projection layers (registered in this order so state-dict keys stay fc1, fc2).
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        self.activation_fn = ACT2FN[activation_function]
        self.activation_dropout = config.activation_dropout
        self.dropout = config.dropout

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        drop = nn.functional.dropout
        expanded = self.activation_fn(self.fc1(hidden_states))
        expanded = drop(expanded, p=self.activation_dropout, training=self.training)
        return drop(self.fc2(expanded), p=self.dropout, training=self.training)
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
class RTDetrFrozenBatchNorm2d(DetrFrozenBatchNorm2d):
    """RT-DETR alias of `DetrFrozenBatchNorm2d`; inherits all behavior unchanged."""
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
class RTDetrSelfAttention(DetrSelfAttention):
    """RT-DETR alias of `DetrSelfAttention`; inherits all behavior unchanged."""
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
def get_contrastive_denoising_training_group(
    targets,
    num_classes,
    num_queries,
    class_embed,
    num_denoising_queries=100,
    label_noise_ratio=0.5,
    box_noise_scale=1.0,
):
    """
    Build the contrastive denoising query group from ground-truth annotations by perturbing labels and boxes.

    Each group holds a positive half (mildly noised ground truth) and a negative half (more strongly noised
    ground truth), and the groups are repeated as many times as fit in `num_denoising_queries`.

    Args:
        targets (`list[dict]`):
            Per-image targets, each with a `"class_labels"` tensor and a `"boxes"` tensor.
        num_classes (`int`):
            Total number of classes in the dataset; also used as the padding label.
        num_queries (`int`):
            Number of regular (matching) query slots in the transformer.
        class_embed (`callable`):
            Function or layer that embeds the (possibly noised) class labels.
        num_denoising_queries (`int`, *optional*, defaults to 100):
            Budget of denoising queries; a value `<= 0` disables denoising.
        label_noise_ratio (`float`, *optional*, defaults to 0.5):
            Label noise ratio; each label is resampled with probability `label_noise_ratio * 0.5`.
        box_noise_scale (`float`, *optional*, defaults to 1.0):
            Scale of the noise applied to the bounding boxes.

    Returns:
        `tuple` comprising various elements:
        - **input_query_class** (`torch.FloatTensor`) --
          Embedded class queries with applied label noise.
        - **input_query_bbox** (`torch.FloatTensor`) --
          Bounding box queries with applied box noise.
        - **attn_mask** (`torch.FloatTensor`) --
          Additive attention mask (0 / -inf) separating denoising groups from each other and from the matching
          queries.
        - **denoising_meta_values** (`dict`) --
          Positive denoising indices, number of groups and the `[denoising, matching]` split sizes.
        All four are `None` when denoising is disabled or no image has any ground-truth object.
    """
    if num_denoising_queries <= 0:
        return None, None, None, None

    gt_counts = [len(target["class_labels"]) for target in targets]
    device = targets[0]["class_labels"].device

    max_gt_num = max(gt_counts)
    if max_gt_num == 0:
        return None, None, None, None

    # Number of denoising groups (at least one).
    num_groups = max(num_denoising_queries // max_gt_num, 1)
    batch_size = len(gt_counts)
    group_span = 2 * max_gt_num  # positive half + negative half

    # Pad every image's ground truth up to max_gt_num; padded slots get label `num_classes`.
    query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
    query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
    pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)
    for batch_idx, (target, count) in enumerate(zip(targets, gt_counts)):
        if count == 0:
            continue
        query_class[batch_idx, :count] = target["class_labels"]
        query_bbox[batch_idx, :count] = target["boxes"]
        pad_gt_mask[batch_idx, :count] = 1

    # Repeat for every group, each of which has a positive and a negative half.
    query_class = query_class.tile([1, 2 * num_groups])
    query_bbox = query_bbox.tile([1, 2 * num_groups, 1])
    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups])

    # 1 marks the negative half of each group, 0 the positive half.
    negative_gt_mask = torch.zeros([batch_size, group_span, 1], device=device)
    negative_gt_mask[:, max_gt_num:] = 1
    negative_gt_mask = negative_gt_mask.tile([1, num_groups, 1])

    # Indices of real (non-padded) positive queries, split per image.
    positive_gt_mask = (1 - negative_gt_mask).squeeze(-1) * pad_gt_mask
    positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
    positive_idx = torch.split(positive_idx, [count * num_groups for count in gt_counts])

    num_denoising_queries = torch_int(group_span * num_groups)

    if label_noise_ratio > 0:
        flip = torch.rand_like(query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
        random_labels = torch.randint_like(flip, 0, num_classes, dtype=query_class.dtype)
        query_class = torch.where(flip & pad_gt_mask, random_labels, query_class)

    if box_noise_scale > 0:
        corners = center_to_corners_format(query_bbox)
        half_extent = torch.tile(query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
        signs = torch.randint_like(query_bbox, 0, 2) * 2.0 - 1.0
        magnitude = torch.rand_like(query_bbox)
        # Negatives are pushed further out: magnitude in [1, 2) instead of [0, 1).
        magnitude = (magnitude + 1.0) * negative_gt_mask + magnitude * (1 - negative_gt_mask)
        magnitude *= signs
        corners += magnitude * half_extent
        corners.clip_(min=0.0, max=1.0)
        query_bbox = corners_to_center_format(corners)
        query_bbox = inverse_sigmoid(query_bbox)

    query_class = class_embed(query_class)

    total = num_denoising_queries + num_queries
    attn_mask = torch.zeros([total, total], dtype=torch.float, device=device)
    # Matching queries must not see the denoising (reconstruction) queries.
    attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
    # Denoising groups must not see each other.
    for group in range(num_groups):
        start = group_span * group
        end = start + group_span
        attn_mask[start:end, :start] = -torch.inf
        attn_mask[start:end, end:num_denoising_queries] = -torch.inf

    denoising_meta_values = {
        "dn_positive_idx": positive_idx,
        "dn_num_group": num_groups,
        "dn_num_split": [num_denoising_queries, num_queries],
    }

    return query_class, query_bbox, attn_mask, denoising_meta_values
|
|
678
|
+
|
|
679
|
+
|
|
680
|
+
class RTDetrConvEncoder(nn.Module):
    """
    Convolutional backbone wrapper (backbone obtained via `load_backbone(config)`).

    When `config.freeze_backbone_batch_norms` is set, the backbone's batch-norm layers are replaced with frozen
    batch norm via `replace_batch_norm`. The forward pass pairs every backbone feature map with the pixel mask
    resized to that map's spatial size.
    https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/nn/backbone/presnet.py#L142
    """

    def __init__(self, config):
        super().__init__()

        backbone = load_backbone(config)
        if config.freeze_backbone_batch_norms:
            # Swap batch norm for its frozen variant without tracking gradients.
            with torch.no_grad():
                replace_batch_norm(backbone)

        self.model = backbone
        self.intermediate_channel_sizes = backbone.channels

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        feature_maps = self.model(pixel_values).feature_maps
        # Add a leading dim so `interpolate` sees (1, batch, H, W); the [0] below removes it again.
        mask_as_image = pixel_mask[None].float()
        return [
            (feature_map, nn.functional.interpolate(mask_as_image, size=feature_map.shape[-2:]).to(torch.bool)[0])
            for feature_map in feature_maps
        ]
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
class RTDetrConvNormLayer(nn.Module):
    """
    Bias-free `nn.Conv2d`, then `nn.BatchNorm2d` (eps = `config.batch_norm_eps`), then an optional activation.

    When `padding` is None it defaults to `(kernel_size - 1) // 2`. When `activation` is None the activation is
    `nn.Identity`; otherwise it is instantiated from `ACT2CLS[activation]`.
    """

    def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
        super().__init__()
        if padding is None:
            padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=False)
        self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
        if activation is None:
            self.activation = nn.Identity()
        else:
            self.activation = ACT2CLS[activation]()

    def forward(self, hidden_state):
        return self.activation(self.norm(self.conv(hidden_state)))
|
|
731
|
+
|
|
732
|
+
|
|
733
|
+
class RTDetrEncoderLayer(nn.Module):
    """
    Transformer encoder layer: a self-attention sub-block followed by a feed-forward (`RTDetrMLP`) sub-block, each
    wrapped in a residual connection. `config.normalize_before` selects pre-norm (LayerNorm before each sub-block)
    or post-norm (LayerNorm after each residual addition).
    """

    def __init__(self, config: RTDetrConfig):
        super().__init__()
        self.normalize_before = config.normalize_before
        self.hidden_size = config.encoder_hidden_dim

        # self-attention
        self.self_attn = RTDetrSelfAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_attention_heads=config.num_attention_heads,
            dropout=config.dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout = config.dropout
        self.mlp = RTDetrMLP(config, self.hidden_size, config.encoder_ffn_dim, config.encoder_activation_function)
        self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        spatial_position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
                Spatial position embeddings (2D positional encodings of image locations), to be added to both
                the queries and keys in self-attention (but not to values).
        """
        # --- Self-attention sub-block ---
        residual = hidden_states
        if self.normalize_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_embeddings=spatial_position_embeddings,
            **kwargs,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        if not self.normalize_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # --- Feed-forward sub-block ---
        # Capture the residual *before* the pre-norm so the skip connection carries the un-normalized stream,
        # exactly like the self-attention sub-block above. (Post-norm behavior is unaffected.)
        residual = hidden_states
        if self.normalize_before:
            hidden_states = self.final_layer_norm(hidden_states)

        hidden_states = self.mlp(hidden_states)

        hidden_states = residual + hidden_states
        if not self.normalize_before:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.training:
            # Guard against overflow during training by clamping to the dtype's finite range.
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        return hidden_states
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
class RTDetrRepVggBlock(nn.Module):
    """
    RepVGG block ("RepVGG: Making VGG-style ConvNets Great Again"): the outputs of a 3x3 and a 1x1 conv-norm
    branch (both without activation) are summed and passed through `config.activation_function`.
    """

    def __init__(self, config: RTDetrConfig):
        super().__init__()

        act_name = config.activation_function
        channels = int(config.encoder_hidden_dim * config.hidden_expansion)
        # 3x3 branch and 1x1 branch operating at the same width.
        self.conv1 = RTDetrConvNormLayer(config, channels, channels, 3, 1, padding=1)
        self.conv2 = RTDetrConvNormLayer(config, channels, channels, 1, 1, padding=0)
        self.activation = ACT2CLS[act_name]() if act_name is not None else nn.Identity()

    def forward(self, x):
        return self.activation(self.conv1(x) + self.conv2(x))
|
|
819
|
+
|
|
820
|
+
|
|
821
|
+
class RTDetrCSPRepLayer(nn.Module):
    """
    Cross Stage Partial (CSP) layer built from RepVGG blocks.

    The input (`2 * config.encoder_hidden_dim` channels) goes through two parallel 1x1 conv-norm projections.
    One of them is refined by three `RTDetrRepVggBlock`s, and the two branches are summed. A final 1x1 conv-norm
    maps back to `config.encoder_hidden_dim` channels only when the hidden width differs from it.
    """

    def __init__(self, config: RTDetrConfig):
        super().__init__()

        out_channels = config.encoder_hidden_dim
        in_channels = 2 * out_channels
        hidden_channels = int(out_channels * config.hidden_expansion)
        activation = config.activation_function
        num_blocks = 3

        self.conv1 = RTDetrConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
        self.conv2 = RTDetrConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
        self.bottlenecks = nn.Sequential(*(RTDetrRepVggBlock(config) for _ in range(num_blocks)))
        self.conv3 = (
            RTDetrConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
            if hidden_channels != out_channels
            else nn.Identity()
        )

    def forward(self, hidden_state):
        refined = self.bottlenecks(self.conv1(hidden_state))
        shortcut = self.conv2(hidden_state)
        return self.conv3(refined + shortcut)
|
|
848
|
+
|
|
849
|
+
|
|
850
|
+
class RTDetrMultiscaleDeformableAttention(DeformableDetrMultiscaleDeformableAttention):
    """RT-DETR alias of `DeformableDetrMultiscaleDeformableAttention`; inherits all behavior unchanged."""
|
|
852
|
+
|
|
853
|
+
|
|
854
|
+
class RTDetrDecoderLayer(nn.Module):
    """
    One RT-DETR decoder layer.

    Applies, in order:
    1. self-attention over the object queries;
    2. multi-scale deformable cross-attention into the flattened encoder features;
    3. a feed-forward MLP.

    Each sub-block is wrapped in a residual connection followed by a LayerNorm (post-norm). Dropout with probability
    `config.dropout` is applied to the output of both attention sub-blocks before the residual addition.
    """

    def __init__(self, config: RTDetrConfig):
        super().__init__()
        self.hidden_size = config.d_model

        # self-attention
        self.self_attn = RTDetrSelfAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_attention_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.dropout = config.dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        # cross-attention
        self.encoder_attn = RTDetrMultiscaleDeformableAttention(
            config,
            num_heads=config.decoder_attention_heads,
            n_points=config.decoder_n_points,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        # feedforward neural networks
        self.mlp = RTDetrMLP(config, self.hidden_size, config.decoder_ffn_dim, config.decoder_activation_function)
        self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        object_queries_position_embeddings: torch.Tensor | None = None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Object query features of shape `(batch, num_queries, hidden_size)`.
            object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
                Position embeddings of the object queries. Passed as `position_embeddings` to both the
                self-attention and the deformable cross-attention.
            reference_points (`torch.FloatTensor`, *optional*):
                Reference points used by the deformable cross-attention to choose where to sample.
            spatial_shapes (`torch.LongTensor`, *optional*):
                `(height, width)` of every feature level, as a tensor.
            spatial_shapes_list (`list[tuple[int, int]]`, *optional*):
                Same information as `spatial_shapes`, as a Python list.
            level_start_index (`torch.LongTensor`, *optional*):
                Start offset of each feature level inside the flattened encoder sequence.
            encoder_hidden_states (`torch.FloatTensor`):
                Flattened multi-scale encoder features attended to by the cross-attention.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Despite its name, this mask is only applied to the **self-attention** between object queries. The
                deformable cross-attention over `encoder_hidden_states` is not masked.

        Returns:
            `torch.Tensor`: updated object query features, same shape as `hidden_states`.
        """
        residual = hidden_states

        # Self Attention (this is where `encoder_attention_mask` is consumed)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=encoder_attention_mask,
            position_embeddings=object_queries_position_embeddings,
            **kwargs,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states

        # Cross-Attention (deformable, no attention mask)
        hidden_states, _ = self.encoder_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            position_embeddings=object_queries_position_embeddings,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            **kwargs,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
class RTDetrSinePositionEmbedding(nn.Module):
    """
    Fixed 2D sine-cosine position embedding for the AIFI layers of the RT-DETR hybrid encoder.

    Args:
        embed_dim (`int`, *optional*, defaults to 256):
            Size of every position vector. Must be divisible by 4: one quarter each for sin/cos of the row and
            column coordinates.
        temperature (`int`, *optional*, defaults to 10000):
            Base of the geometric progression of frequencies.
    """

    def __init__(self, embed_dim: int = 256, temperature: int = 10000):
        super().__init__()
        self.embed_dim = embed_dim
        self.temperature = temperature

    @compile_compatible_method_lru_cache(maxsize=32)
    def forward(
        self,
        width: int,
        height: int,
        device: torch.device | str,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """
        Build the embedding for a `height x width` feature map.

        Positions are enumerated row-major (row `k // width`, column `k % width`), which matches
        `hidden_states.flatten(2)` applied to a `(batch, channels, height, width)` map.

        Returns:
            `torch.Tensor` of shape `(1, height * width, embed_dim)`, laid out as
            `[sin(row), cos(row), sin(col), cos(col)]` along the last dimension.

        Raises:
            ValueError: if `embed_dim` is not divisible by 4.
        """
        cols = torch.arange(torch_int(width), device=device).to(dtype)
        rows = torch.arange(torch_int(height), device=device).to(dtype)
        # "xy" indexing returns (height, width) grids, so flattening enumerates positions row by row
        cols, rows = torch.meshgrid(cols, rows, indexing="xy")
        if self.embed_dim % 4 != 0:
            raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
        quarter = self.embed_dim // 4
        exponents = torch.arange(quarter, device=device).to(dtype) / quarter
        inv_freq = 1.0 / (self.temperature**exponents)

        col_angles = cols.flatten()[..., None] @ inv_freq[None]
        row_angles = rows.flatten()[..., None] @ inv_freq[None]

        embedding = torch.concat([row_angles.sin(), row_angles.cos(), col_angles.sin(), col_angles.cos()], dim=1)
        return embedding.unsqueeze(0)
|
|
990
|
+
|
|
991
|
+
|
|
992
|
+
class RTDetrAIFILayer(nn.Module):
    """
    AIFI (Attention-based Intra-scale Feature Interaction) layer used in RT-DETR hybrid encoder.

    Flattens one feature map into a sequence and runs it through `config.encoder_layers` encoder layers. Every layer
    receives a 2D sine-cosine position embedding. The sequence is then folded back into a feature map.
    """

    def __init__(self, config: RTDetrConfig):
        super().__init__()
        self.config = config
        self.encoder_hidden_dim = config.encoder_hidden_dim
        # Kept for backward compatibility. The position embedding is always computed from the actual feature-map
        # size, so this value no longer changes the forward pass.
        self.eval_size = config.eval_size

        self.position_embedding = RTDetrSinePositionEmbedding(
            embed_dim=self.encoder_hidden_dim,
            temperature=config.positional_encoding_temperature,
        )
        self.layers = nn.ModuleList([RTDetrEncoderLayer(config) for _ in range(config.encoder_layers)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                Feature map to process; `channels` must equal `config.encoder_hidden_dim`.

        Returns:
            `torch.Tensor` of shape `(batch_size, encoder_hidden_dim, height, width)`.
        """
        batch_size = hidden_states.shape[0]
        height, width = hidden_states.shape[2:]

        # (batch, channels, height, width) -> (batch, height * width, channels)
        hidden_states = hidden_states.flatten(2).permute(0, 2, 1)

        # The position embedding is always provided. Previously it was replaced by `None` in eval mode whenever
        # `config.eval_size` was set, which silently removed all positional information from the encoder at
        # inference time only. The embedding module is LRU-cached, so recomputing it per call is cheap.
        pos_embed = self.position_embedding(
            width=width,
            height=height,
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=None,
                spatial_position_embeddings=pos_embed,
                **kwargs,
            )

        # back to (batch, encoder_hidden_dim, height, width)
        hidden_states = (
            hidden_states.permute(0, 2, 1).reshape(batch_size, self.encoder_hidden_dim, height, width).contiguous()
        )

        return hidden_states
|
|
1047
|
+
|
|
1048
|
+
|
|
1049
|
+
class RTDetrMLPPredictionHead(DetrMLPPredictionHead):
    """
    Small MLP head inherited unchanged from DETR.

    RT-DETR builds it with `(4, 2 * d_model, d_model, num_layers=2)` for the decoder query-position head, and with
    `(d_model, d_model, 4, num_layers=3)` for the encoder box head.
    """
|
|
1051
|
+
|
|
1052
|
+
|
|
1053
|
+
@auto_docstring
class RTDetrPreTrainedModel(PreTrainedModel):
    config: RTDetrConfig
    base_model_prefix = "rt_detr"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    _no_split_modules = [r"RTDetrHybridEncoder", r"RTDetrDecoderLayer"]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_attention_backend = True
    _supports_flex_attn = True

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize the weights of `module` in place.

        - Classification heads: Xavier weights and a bias derived from `config.initializer_bias_prior_prob`
          (falling back to `1 / (num_labels + 1)`).
        - Box-regression heads: last layer zeroed.
        - Deformable attention: directional sampling-offset initialization.
        - Generic Linear / Conv2d / BatchNorm2d / LayerNorm layers: standard initialization.
        - Learned query and denoising embeddings: Xavier, when enabled in the config.
        """

        def prior_logit_bias() -> float:
            # logit b with sigmoid(b) == prior probability p, i.e. b = -log((1 - p) / p)
            prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
            return float(-math.log((1 - prior_prob) / prior_prob))

        if isinstance(module, RTDetrForObjectDetection):
            decoder = module.model.decoder
            if decoder.class_embed is not None:
                for class_head in decoder.class_embed:
                    bias_value = prior_logit_bias()
                    init.xavier_uniform_(class_head.weight)
                    init.constant_(class_head.bias, bias_value)

            if decoder.bbox_embed is not None:
                for box_head in decoder.bbox_embed:
                    # zero the final layer so an untrained box head outputs a zero offset
                    last_layer = box_head.layers[-1]
                    init.constant_(last_layer.weight, 0)
                    init.constant_(last_layer.bias, 0)

        elif isinstance(module, RTDetrMultiscaleDeformableAttention):
            n_heads, n_levels, n_points = module.n_heads, module.n_levels, module.n_points
            init.constant_(module.sampling_offsets.weight, 0.0)
            # one angle per head, evenly spread around the circle
            angle_step = 2.0 * math.pi / n_heads
            thetas = torch.arange(n_heads, dtype=torch.int64).to(torch.get_default_dtype()) * angle_step
            directions = torch.stack([thetas.cos(), thetas.sin()], -1)
            # scale each direction so its largest |component| is 1, then tile over levels and points
            directions = directions / directions.abs().max(-1, keepdim=True)[0]
            grid_init = directions.view(n_heads, 1, 1, 2).repeat(1, n_levels, n_points, 1)
            # point k of every head starts k + 1 units along that head's direction
            for point in range(n_points):
                grid_init[:, :, point, :] *= point + 1

            init.copy_(module.sampling_offsets.bias, grid_init.view(-1))
            init.constant_(module.attention_weights.weight, 0.0)
            init.constant_(module.attention_weights.bias, 0.0)
            init.xavier_uniform_(module.value_proj.weight)
            init.constant_(module.value_proj.bias, 0.0)
            init.xavier_uniform_(module.output_proj.weight)
            init.constant_(module.output_proj.bias, 0.0)

        elif isinstance(module, RTDetrModel):
            bias_value = prior_logit_bias()
            init.xavier_uniform_(module.enc_score_head.weight)
            init.constant_(module.enc_score_head.bias, bias_value)

        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
            # BatchNorm layers also carry running statistics: reset them to their defaults
            if getattr(module, "running_mean", None) is not None:
                init.zeros_(module.running_mean)
                init.ones_(module.running_var)
                init.zeros_(module.num_batches_tracked)

        elif isinstance(module, nn.LayerNorm):
            init.ones_(module.weight)
            init.zeros_(module.bias)

        # Optional embeddings are checked for every module, independently of the branches above
        if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
            init.xavier_uniform_(module.weight_embedding.weight)
        if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
            init.xavier_uniform_(module.denoising_class_embed.weight)
|
|
1127
|
+
|
|
1128
|
+
|
|
1129
|
+
class RTDetrHybridEncoder(RTDetrPreTrainedModel):
    """
    Hybrid encoder made of three stages:
    - AIFI (Attention-based Intra-scale Feature Interaction) layers;
    - a top-down Feature Pyramid Network (FPN);
    - a bottom-up Path Aggregation Network (PAN).

    More details in the paper: https://huggingface.co/papers/2304.08069

    Args:
        config: RTDetrConfig
    """

    _can_record_outputs = {
        "hidden_states": RTDetrAIFILayer,
        "attentions": RTDetrSelfAttention,
    }

    def __init__(self, config: RTDetrConfig):
        super().__init__(config)
        self.config = config
        self.in_channels = config.encoder_in_channels
        self.feat_strides = config.feat_strides
        self.encoder_hidden_dim = config.encoder_hidden_dim
        self.encode_proj_layers = config.encode_proj_layers
        self.positional_encoding_temperature = config.positional_encoding_temperature
        self.eval_size = config.eval_size
        self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
        self.out_strides = self.feat_strides
        self.num_fpn_stages = len(self.in_channels) - 1
        self.num_pan_stages = len(self.in_channels) - 1

        # AIFI (Attention-based Intra-scale Feature Interaction) layers, one per level listed in encode_proj_layers
        self.aifi = nn.ModuleList([RTDetrAIFILayer(config) for _ in range(len(self.encode_proj_layers))])

        # top-down FPN
        self.lateral_convs = nn.ModuleList()
        self.fpn_blocks = nn.ModuleList()
        for _ in range(self.num_fpn_stages):
            lateral_conv = RTDetrConvNormLayer(
                config,
                in_channels=self.encoder_hidden_dim,
                out_channels=self.encoder_hidden_dim,
                kernel_size=1,
                stride=1,
                activation=config.activation_function,
            )
            fpn_block = RTDetrCSPRepLayer(config)
            self.lateral_convs.append(lateral_conv)
            self.fpn_blocks.append(fpn_block)

        # bottom-up PAN
        self.downsample_convs = nn.ModuleList()
        self.pan_blocks = nn.ModuleList()
        for _ in range(self.num_pan_stages):
            downsample_conv = RTDetrConvNormLayer(
                config,
                in_channels=self.encoder_hidden_dim,
                out_channels=self.encoder_hidden_dim,
                kernel_size=3,
                stride=2,
                activation=config.activation_function,
            )
            pan_block = RTDetrCSPRepLayer(config)
            self.downsample_convs.append(downsample_conv)
            self.pan_blocks.append(pan_block)

        self.post_init()

    @check_model_inputs(tie_last_hidden_states=False)
    def forward(
        self,
        inputs_embeds=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        Args:
            inputs_embeds (sequence of `torch.FloatTensor`):
                Projected backbone feature maps, one `(batch_size, encoder_hidden_dim, height, width)` tensor per
                level, ordered from highest to lowest resolution (each level is expected to be half the size of
                the previous one). A list or a tuple is accepted. It is copied and never modified in place.

        Returns:
            `BaseModelOutput` whose `last_hidden_state` is the list of PAN output feature maps, one per level, in
            the same order as the input.
        """
        # Work on a shallow copy. The AIFI step below replaces entries, which previously mutated the caller's
        # list in place and failed outright when a tuple was passed.
        feature_maps = list(inputs_embeds)

        # AIFI: Apply transformer encoder to specified feature levels
        if self.config.encoder_layers > 0:
            for i, enc_ind in enumerate(self.encode_proj_layers):
                feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs)

        # top-down FPN: start from the coarsest map, repeatedly upsample x2 and fuse with the next finer level
        fpn_feature_maps = [feature_maps[-1]]
        for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
            backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1]
            top_fpn_feature_map = fpn_feature_maps[-1]
            # apply lateral block (its output replaces the stored map, which is later reused by the PAN)
            top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
            fpn_feature_maps[-1] = top_fpn_feature_map
            # apply fpn block
            top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
            fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
            new_fpn_feature_map = fpn_block(fused_feature_map)
            fpn_feature_maps.append(new_fpn_feature_map)

        # restore finest-to-coarsest order
        fpn_feature_maps.reverse()

        # bottom-up PAN: start from the finest map, repeatedly downsample (stride-2 conv) and fuse with next level
        pan_feature_maps = [fpn_feature_maps[0]]
        for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
            top_pan_feature_map = pan_feature_maps[-1]
            fpn_feature_map = fpn_feature_maps[idx + 1]
            downsampled_feature_map = downsample_conv(top_pan_feature_map)
            fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
            new_pan_feature_map = pan_block(fused_feature_map)
            pan_feature_maps.append(new_pan_feature_map)

        return BaseModelOutput(last_hidden_state=pan_feature_maps)
|
|
1240
|
+
|
|
1241
|
+
|
|
1242
|
+
class RTDetrDecoder(RTDetrPreTrainedModel):
    """
    Stack of `config.decoder_layers` RT-DETR decoder layers with optional iterative box refinement.

    `bbox_embed` and `class_embed` start as `None`. When they are set to per-layer heads (presumably by a detection
    wrapper), each layer's output refines the reference boxes and yields per-layer classification logits.
    """

    _can_record_outputs = {
        "hidden_states": RTDetrDecoderLayer,
        "attentions": RTDetrSelfAttention,
        "cross_attentions": RTDetrMultiscaleDeformableAttention,
    }

    def __init__(self, config: RTDetrConfig):
        super().__init__(config)

        self.dropout = config.dropout
        self.layers = nn.ModuleList([RTDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
        self.query_pos_head = RTDetrMLPPredictionHead(4, 2 * config.d_model, config.d_model, num_layers=2)

        # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
        self.bbox_embed = None
        self.class_embed = None

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    def forward(
        self,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
                The query embeddings that are passed into the decoder. Required.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Flattened multi-scale encoder features used by the deformable cross-attention of every layer.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Forwarded to every decoder layer, where it masks the **self-attention** between object queries. It
                is not applied to the cross-attention.
            reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
                Initial reference boxes in logit (pre-sigmoid) space; `sigmoid` is applied before the first layer.
                Required.
            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
                `(height, width)` of each feature level.
            spatial_shapes_list (`list[tuple[int, int]]`, *optional*):
                Same information as `spatial_shapes`, as a Python list.
            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
                Indexes for the start of each feature level. In range `[0, sequence_length]`.

        Returns:
            `RTDetrDecoderOutput` containing:
            - the last layer's hidden states;
            - stacked along dim 1 (one entry per layer): the intermediate hidden states and reference points
              (post-sigmoid);
            - the intermediate logits, only when `class_embed` is set. Otherwise `intermediate_logits` is an
              empty tuple.

        Raises:
            ValueError: if `inputs_embeds` or `reference_points` is not provided.
        """
        # Fail early with a clear message instead of an UnboundLocalError / sigmoid-of-None error further down.
        if inputs_embeds is None:
            raise ValueError("`inputs_embeds` (the initial query embeddings) must be provided to the decoder.")
        if reference_points is None:
            raise ValueError("`reference_points` must be provided to the decoder.")
        hidden_states = inputs_embeds

        # decoder layers
        intermediate = ()
        intermediate_reference_points = ()
        intermediate_logits = ()

        reference_points = F.sigmoid(reference_points)

        # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L252
        for idx, decoder_layer in enumerate(self.layers):
            reference_points_input = reference_points.unsqueeze(2)
            object_queries_position_embeddings = self.query_pos_head(reference_points)

            hidden_states = decoder_layer(
                hidden_states,
                object_queries_position_embeddings=object_queries_position_embeddings,
                encoder_hidden_states=encoder_hidden_states,
                reference_points=reference_points_input,
                spatial_shapes=spatial_shapes,
                spatial_shapes_list=spatial_shapes_list,
                level_start_index=level_start_index,
                encoder_attention_mask=encoder_attention_mask,
                **kwargs,
            )

            # hack implementation for iterative bounding box refinement
            if self.bbox_embed is not None:
                predicted_corners = self.bbox_embed[idx](hidden_states)
                new_reference_points = F.sigmoid(predicted_corners + inverse_sigmoid(reference_points))
                # detached so gradients do not flow through the refinement chain into earlier layers
                reference_points = new_reference_points.detach()

            intermediate += (hidden_states,)
            intermediate_reference_points += (
                (new_reference_points,) if self.bbox_embed is not None else (reference_points,)
            )

            if self.class_embed is not None:
                logits = self.class_embed[idx](hidden_states)
                intermediate_logits += (logits,)

        # Keep batch_size as first dimension
        intermediate = torch.stack(intermediate, dim=1)
        intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
        if self.class_embed is not None:
            intermediate_logits = torch.stack(intermediate_logits, dim=1)

        return RTDetrDecoderOutput(
            last_hidden_state=hidden_states,
            intermediate_hidden_states=intermediate,
            intermediate_logits=intermediate_logits,
            intermediate_reference_points=intermediate_reference_points,
        )
|
|
1348
|
+
|
|
1349
|
+
|
|
1350
|
+
@auto_docstring(
|
|
1351
|
+
custom_intro="""
|
|
1352
|
+
RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
|
|
1353
|
+
"""
|
|
1354
|
+
)
|
|
1355
|
+
class RTDetrModel(RTDetrPreTrainedModel):
|
|
1356
|
+
def __init__(self, config: RTDetrConfig):
    # Submodules are created in a fixed order; keep it, since it determines parameter registration order and
    # random-initialization order.
    super().__init__(config)

    # Backbone: exposes the channel count of every feature level it returns
    self.backbone = RTDetrConvEncoder(config)

    # Encoder input projections: one (1x1 conv, BatchNorm) pair per backbone level, mapping to encoder_hidden_dim
    # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py#L212
    encoder_projections = []
    for in_channels in self.backbone.intermediate_channel_sizes:
        projection_conv = nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False)
        projection_norm = nn.BatchNorm2d(config.encoder_hidden_dim)
        encoder_projections.append(nn.Sequential(projection_conv, projection_norm))
    self.encoder_input_proj = nn.ModuleList(encoder_projections)

    # Hybrid encoder (AIFI + FPN + PAN)
    self.encoder = RTDetrHybridEncoder(config)

    # Label embedding for denoising queries; index `num_labels` is the padding class
    if config.num_denoising > 0:
        self.denoising_class_embed = nn.Embedding(
            config.num_labels + 1, config.d_model, padding_idx=config.num_labels
        )

    # Optional learned initial query content
    if config.learn_initial_query:
        self.weight_embedding = nn.Embedding(config.num_queries, config.d_model)

    # Encoder heads: output transform, class scores and box regression
    self.enc_output = nn.Sequential(
        nn.Linear(config.d_model, config.d_model),
        nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
    )
    self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
    self.enc_bbox_head = RTDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)

    # Precompute anchors and their validity mask when a fixed anchor image size is configured
    if config.anchor_image_size:
        self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)

    # Decoder input projections: 1x1 conv per listed level, then 3x3 stride-2 convs for any extra (coarser)
    # levels. The first extra level takes the last listed level's channels; later ones take d_model.
    # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
    decoder_projections = []
    for in_channels in config.decoder_in_channels:
        decoder_projections.append(
            nn.Sequential(
                nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
                nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
            )
        )
    num_extra_levels = config.num_feature_levels - len(config.decoder_in_channels)
    for _ in range(num_extra_levels):
        decoder_projections.append(
            nn.Sequential(
                nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
            )
        )
        in_channels = config.d_model
    self.decoder_input_proj = nn.ModuleList(decoder_projections)

    # Decoder
    self.decoder = RTDetrDecoder(config)

    self.post_init()
|
|
1428
|
+
|
|
1429
|
+
def freeze_backbone(self):
    """Disable gradient tracking (`requires_grad=False`) for every parameter of the backbone."""
    self.backbone.requires_grad_(False)
|
|
1432
|
+
|
|
1433
|
+
def unfreeze_backbone(self):
    """Re-enable gradient tracking (`requires_grad=True`) for every parameter of the backbone."""
    self.backbone.requires_grad_(True)
|
|
1436
|
+
|
|
1437
|
+
@compile_compatible_method_lru_cache(maxsize=32)
def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
    """
    Build the encoder anchor boxes, in logit space, for every position of every feature level.

    Args:
        spatial_shapes (hashable sequence of `(height, width)`, *optional*):
            Per-level feature map sizes. When `None`, each level is `anchor_image_size` divided by the
            corresponding entry of `config.feat_strides` (truncated to int). Must be hashable because this
            method is LRU-cached.
        grid_size (`float`, *optional*, defaults to 0.05):
            Normalized anchor width/height at the first level; it doubles at every subsequent level.
        device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
            Device of the returned tensors.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Dtype of the returned anchors.

    Returns:
        `tuple(anchors, valid_mask)`:
        - `anchors`: shape `(1, total_positions, 4)`, holding `log(a / (1 - a))` of the normalized
          `(cx, cy, w, h)`. Invalid entries are replaced by the dtype's max value.
        - `valid_mask`: boolean, shape `(1, total_positions, 1)`. True where all four coordinates lie strictly
          inside `(0.01, 0.99)`.
    """
    if spatial_shapes is None:
        image_height = self.config.anchor_image_size[0]
        image_width = self.config.anchor_image_size[1]
        spatial_shapes = [
            [int(image_height / stride), int(image_width / stride)] for stride in self.config.feat_strides
        ]

    per_level_anchors = []
    for level, (height, width) in enumerate(spatial_shapes):
        ys = torch.arange(end=height, device=device).to(dtype)
        xs = torch.arange(end=width, device=device).to(dtype)
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        # cell centres (x, y), normalized to [0, 1]
        centers = torch.stack([grid_x, grid_y], -1).unsqueeze(0) + 0.5
        centers[..., 0] /= width
        centers[..., 1] /= height
        sizes = torch.ones_like(centers) * grid_size * (2.0**level)
        per_level_anchors.append(torch.concat([centers, sizes], -1).reshape(-1, height * width, 4))

    anchors = torch.concat(per_level_anchors, 1)
    # anchors too close to the image border (or degenerate) are flagged invalid
    eps = 1e-2
    valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
    # move to logit space (inverse sigmoid)
    anchors = torch.log(anchors / (1 - anchors))
    invalid_fill = torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device)
    anchors = torch.where(valid_mask, anchors, invalid_fill)

    return anchors, valid_mask
|
|
1465
|
+
|
|
1466
|
+
@auto_docstring
|
|
1467
|
+
@can_return_tuple
|
|
1468
|
+
def forward(
|
|
1469
|
+
self,
|
|
1470
|
+
pixel_values: torch.FloatTensor,
|
|
1471
|
+
pixel_mask: torch.LongTensor | None = None,
|
|
1472
|
+
encoder_outputs: torch.FloatTensor | None = None,
|
|
1473
|
+
inputs_embeds: torch.FloatTensor | None = None,
|
|
1474
|
+
labels: list[dict] | None = None,
|
|
1475
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
1476
|
+
) -> tuple[torch.FloatTensor] | RTDetrModelOutput:
|
|
1477
|
+
r"""
|
|
1478
|
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
1479
|
+
Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
|
|
1480
|
+
can choose to directly pass a flattened representation of an image.
|
|
1481
|
+
labels (`list[Dict]` of len `(batch_size,)`, *optional*):
|
|
1482
|
+
Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
|
|
1483
|
+
following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
|
|
1484
|
+
respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
|
|
1485
|
+
in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
|
|
1486
|
+
|
|
1487
|
+
Examples:
|
|
1488
|
+
|
|
1489
|
+
```python
|
|
1490
|
+
>>> from transformers import AutoImageProcessor, RTDetrModel
|
|
1491
|
+
>>> from PIL import Image
|
|
1492
|
+
>>> import requests
|
|
1493
|
+
|
|
1494
|
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
|
1495
|
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
1496
|
+
|
|
1497
|
+
>>> image_processor = AutoImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
|
|
1498
|
+
>>> model = RTDetrModel.from_pretrained("PekingU/rtdetr_r50vd")
|
|
1499
|
+
|
|
1500
|
+
>>> inputs = image_processor(images=image, return_tensors="pt")
|
|
1501
|
+
|
|
1502
|
+
>>> outputs = model(**inputs)
|
|
1503
|
+
|
|
1504
|
+
>>> last_hidden_states = outputs.last_hidden_state
|
|
1505
|
+
>>> list(last_hidden_states.shape)
|
|
1506
|
+
[1, 300, 256]
|
|
1507
|
+
```"""
|
|
1508
|
+
if pixel_values is None and inputs_embeds is None:
|
|
1509
|
+
raise ValueError("You have to specify either pixel_values or inputs_embeds")
|
|
1510
|
+
|
|
1511
|
+
if inputs_embeds is None:
|
|
1512
|
+
batch_size, num_channels, height, width = pixel_values.shape
|
|
1513
|
+
device = pixel_values.device
|
|
1514
|
+
if pixel_mask is None:
|
|
1515
|
+
pixel_mask = torch.ones(((batch_size, height, width)), device=device)
|
|
1516
|
+
features = self.backbone(pixel_values, pixel_mask)
|
|
1517
|
+
proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
|
|
1518
|
+
else:
|
|
1519
|
+
batch_size = inputs_embeds.shape[0]
|
|
1520
|
+
device = inputs_embeds.device
|
|
1521
|
+
proj_feats = inputs_embeds
|
|
1522
|
+
|
|
1523
|
+
if encoder_outputs is None:
|
|
1524
|
+
encoder_outputs = self.encoder(
|
|
1525
|
+
proj_feats,
|
|
1526
|
+
**kwargs,
|
|
1527
|
+
)
|
|
1528
|
+
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
|
|
1529
|
+
elif not isinstance(encoder_outputs, BaseModelOutput):
|
|
1530
|
+
encoder_outputs = BaseModelOutput(
|
|
1531
|
+
last_hidden_state=encoder_outputs[0],
|
|
1532
|
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
|
1533
|
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
|
1534
|
+
)
|
|
1535
|
+
|
|
1536
|
+
# Equivalent to def _get_encoder_input
|
|
1537
|
+
# https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
|
|
1538
|
+
sources = []
|
|
1539
|
+
for level, source in enumerate(encoder_outputs.last_hidden_state):
|
|
1540
|
+
sources.append(self.decoder_input_proj[level](source))
|
|
1541
|
+
|
|
1542
|
+
# Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
|
|
1543
|
+
if self.config.num_feature_levels > len(sources):
|
|
1544
|
+
_len_sources = len(sources)
|
|
1545
|
+
sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state)[-1])
|
|
1546
|
+
for i in range(_len_sources + 1, self.config.num_feature_levels):
|
|
1547
|
+
sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1]))
|
|
1548
|
+
|
|
1549
|
+
# Prepare encoder inputs (by flattening)
|
|
1550
|
+
source_flatten = []
|
|
1551
|
+
spatial_shapes_list = []
|
|
1552
|
+
spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long)
|
|
1553
|
+
for level, source in enumerate(sources):
|
|
1554
|
+
height, width = source.shape[-2:]
|
|
1555
|
+
spatial_shapes[level, 0] = height
|
|
1556
|
+
spatial_shapes[level, 1] = width
|
|
1557
|
+
spatial_shapes_list.append((height, width))
|
|
1558
|
+
source = source.flatten(2).transpose(1, 2)
|
|
1559
|
+
source_flatten.append(source)
|
|
1560
|
+
source_flatten = torch.cat(source_flatten, 1)
|
|
1561
|
+
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
|
1562
|
+
|
|
1563
|
+
# prepare denoising training
|
|
1564
|
+
if self.training and self.config.num_denoising > 0 and labels is not None:
|
|
1565
|
+
(
|
|
1566
|
+
denoising_class,
|
|
1567
|
+
denoising_bbox_unact,
|
|
1568
|
+
attention_mask,
|
|
1569
|
+
denoising_meta_values,
|
|
1570
|
+
) = get_contrastive_denoising_training_group(
|
|
1571
|
+
targets=labels,
|
|
1572
|
+
num_classes=self.config.num_labels,
|
|
1573
|
+
num_queries=self.config.num_queries,
|
|
1574
|
+
class_embed=self.denoising_class_embed,
|
|
1575
|
+
num_denoising_queries=self.config.num_denoising,
|
|
1576
|
+
label_noise_ratio=self.config.label_noise_ratio,
|
|
1577
|
+
box_noise_scale=self.config.box_noise_scale,
|
|
1578
|
+
)
|
|
1579
|
+
else:
|
|
1580
|
+
denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None
|
|
1581
|
+
|
|
1582
|
+
batch_size = len(source_flatten)
|
|
1583
|
+
device = source_flatten.device
|
|
1584
|
+
dtype = source_flatten.dtype
|
|
1585
|
+
|
|
1586
|
+
# prepare input for decoder
|
|
1587
|
+
if self.training or self.config.anchor_image_size is None:
|
|
1588
|
+
# Pass spatial_shapes as tuple to make it hashable and make sure
|
|
1589
|
+
# lru_cache is working for generate_anchors()
|
|
1590
|
+
spatial_shapes_tuple = tuple(spatial_shapes_list)
|
|
1591
|
+
anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
|
|
1592
|
+
else:
|
|
1593
|
+
anchors, valid_mask = self.anchors, self.valid_mask
|
|
1594
|
+
anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)
|
|
1595
|
+
|
|
1596
|
+
# use the valid_mask to selectively retain values in the feature map where the mask is `True`
|
|
1597
|
+
memory = valid_mask.to(source_flatten.dtype) * source_flatten
|
|
1598
|
+
|
|
1599
|
+
output_memory = self.enc_output(memory)
|
|
1600
|
+
|
|
1601
|
+
enc_outputs_class = self.enc_score_head(output_memory)
|
|
1602
|
+
enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors
|
|
1603
|
+
|
|
1604
|
+
_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1)
|
|
1605
|
+
|
|
1606
|
+
reference_points_unact = enc_outputs_coord_logits.gather(
|
|
1607
|
+
dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1])
|
|
1608
|
+
)
|
|
1609
|
+
|
|
1610
|
+
enc_topk_bboxes = F.sigmoid(reference_points_unact)
|
|
1611
|
+
if denoising_bbox_unact is not None:
|
|
1612
|
+
reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
|
|
1613
|
+
|
|
1614
|
+
enc_topk_logits = enc_outputs_class.gather(
|
|
1615
|
+
dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
|
|
1616
|
+
)
|
|
1617
|
+
|
|
1618
|
+
# extract region features
|
|
1619
|
+
if self.config.learn_initial_query:
|
|
1620
|
+
target = self.weight_embedding.tile([batch_size, 1, 1])
|
|
1621
|
+
else:
|
|
1622
|
+
target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
|
|
1623
|
+
target = target.detach()
|
|
1624
|
+
|
|
1625
|
+
if denoising_class is not None:
|
|
1626
|
+
target = torch.concat([denoising_class, target], 1)
|
|
1627
|
+
|
|
1628
|
+
init_reference_points = reference_points_unact.detach()
|
|
1629
|
+
|
|
1630
|
+
# decoder
|
|
1631
|
+
decoder_outputs = self.decoder(
|
|
1632
|
+
inputs_embeds=target,
|
|
1633
|
+
encoder_hidden_states=source_flatten,
|
|
1634
|
+
encoder_attention_mask=attention_mask,
|
|
1635
|
+
reference_points=init_reference_points,
|
|
1636
|
+
spatial_shapes=spatial_shapes,
|
|
1637
|
+
spatial_shapes_list=spatial_shapes_list,
|
|
1638
|
+
level_start_index=level_start_index,
|
|
1639
|
+
**kwargs,
|
|
1640
|
+
)
|
|
1641
|
+
|
|
1642
|
+
return RTDetrModelOutput(
|
|
1643
|
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
|
1644
|
+
intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
|
|
1645
|
+
intermediate_logits=decoder_outputs.intermediate_logits,
|
|
1646
|
+
intermediate_reference_points=decoder_outputs.intermediate_reference_points,
|
|
1647
|
+
intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners,
|
|
1648
|
+
initial_reference_points=decoder_outputs.initial_reference_points,
|
|
1649
|
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
|
1650
|
+
decoder_attentions=decoder_outputs.attentions,
|
|
1651
|
+
cross_attentions=decoder_outputs.cross_attentions,
|
|
1652
|
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
|
1653
|
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
|
1654
|
+
encoder_attentions=encoder_outputs.attentions,
|
|
1655
|
+
init_reference_points=init_reference_points,
|
|
1656
|
+
enc_topk_logits=enc_topk_logits,
|
|
1657
|
+
enc_topk_bboxes=enc_topk_bboxes,
|
|
1658
|
+
enc_outputs_class=enc_outputs_class,
|
|
1659
|
+
enc_outputs_coord_logits=enc_outputs_coord_logits,
|
|
1660
|
+
denoising_meta_values=denoising_meta_values,
|
|
1661
|
+
)
|
|
1662
|
+
|
|
1663
|
+
|
|
1664
|
+
@auto_docstring(
    custom_intro="""
    RT-DETR Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further
    decoded into scores and classes.
    """
)
class RTDetrForObjectDetection(RTDetrPreTrainedModel):
    # We can't initialize the model on meta device as some weights are modified during the initialization
    _no_split_modules = None

    def __init__(self, config: RTDetrConfig):
        super().__init__(config)
        self.model = RTDetrModel(config)

        # Exactly one classification head and one box-regression head per decoder layer (there is no extra
        # head for region proposals). They are attached to the decoder; presumably the decoder uses them to
        # produce the per-layer `intermediate_logits` / `intermediate_reference_points` that `forward` reads
        # -- confirm against the decoder implementation.
        num_pred = config.decoder_layers
        self.model.decoder.class_embed = nn.ModuleList(
            [nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]
        )
        self.model.decoder.bbox_embed = nn.ModuleList(
            [RTDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) for _ in range(num_pred)]
        )
        self.post_init()

    def _set_aux_loss(self, outputs_class, outputs_coord):
        """Pair per-layer class logits with per-layer boxes.

        Args:
            outputs_class: iterable of class logits, one entry per decoder layer.
            outputs_coord: iterable of predicted boxes, aligned with `outputs_class`.

        Returns:
            A list of `{"logits": ..., "pred_boxes": ...}` dicts, one per pair produced by `zip`
            (so its length is that of the shorter input).

        NOTE(review): not called from within this class; presumably kept for external callers -- confirm
        before removing.
        """
        return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: torch.LongTensor | None = None,
        encoder_outputs: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: list[dict] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | RTDetrObjectDetectionOutput:
        r"""
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image.
        labels (`list[dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Examples:

        ```python
        >>> from transformers import RTDetrImageProcessor, RTDetrForObjectDetection
        >>> from PIL import Image
        >>> import requests
        >>> import torch

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
        >>> model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 300, 80]

        >>> boxes = outputs.pred_boxes
        >>> list(boxes.shape)
        [1, 300, 4]

        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
        ...     0
        ... ]

        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected sofa with confidence 0.97 at location [0.14, 0.38, 640.13, 476.21]
        Detected cat with confidence 0.96 at location [343.38, 24.28, 640.14, 371.5]
        Detected cat with confidence 0.958 at location [13.23, 54.18, 318.98, 472.22]
        Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48]
        Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99]
        ```"""
        outputs = self.model(
            pixel_values,
            pixel_mask=pixel_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            labels=labels,
            **kwargs,
        )

        # Denoising metadata is only forwarded to the loss while training; in eval mode the loss gets None.
        denoising_meta_values = outputs.denoising_meta_values if self.training else None

        # Per-decoder-layer predictions returned by the base model.
        outputs_class = outputs.intermediate_logits
        outputs_coord = outputs.intermediate_reference_points
        predicted_corners = outputs.intermediate_predicted_corners
        initial_reference_points = outputs.initial_reference_points

        # Final predictions come from the last decoder layer. Dim 1 is assumed to index decoder layers
        # (consistent with `[:, -1]` and the `[1, 300, 80]` shape in the example) -- confirm in the decoder.
        logits = outputs_class[:, -1]
        pred_boxes = outputs_coord[:, -1]

        loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
        if labels is not None:
            # The loss additionally consumes the encoder's top-k proposals and all intermediate-layer outputs.
            enc_topk_logits = outputs.enc_topk_logits
            enc_topk_bboxes = outputs.enc_topk_bboxes
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits,
                labels,
                self.device,
                pred_boxes,
                self.config,
                outputs_class,
                outputs_coord,
                enc_topk_logits=enc_topk_logits,
                enc_topk_bboxes=enc_topk_bboxes,
                denoising_meta_values=denoising_meta_values,
                predicted_corners=predicted_corners,
                initial_reference_points=initial_reference_points,
                **kwargs,
            )

        return RTDetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=outputs.last_hidden_state,
            intermediate_hidden_states=outputs.intermediate_hidden_states,
            intermediate_logits=outputs.intermediate_logits,
            intermediate_reference_points=outputs.intermediate_reference_points,
            intermediate_predicted_corners=outputs.intermediate_predicted_corners,
            initial_reference_points=outputs.initial_reference_points,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
            init_reference_points=outputs.init_reference_points,
            enc_topk_logits=outputs.enc_topk_logits,
            enc_topk_bboxes=outputs.enc_topk_bboxes,
            enc_outputs_class=outputs.enc_outputs_class,
            enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
            denoising_meta_values=outputs.denoising_meta_values,
        )
|
|
1822
|
+
|
|
1823
|
+
|
|
1824
|
+
# Public API of this module: the names exported by `from <module> import *`.
__all__ = ["RTDetrImageProcessorFast", "RTDetrForObjectDetection", "RTDetrModel", "RTDetrPreTrainedModel"]
|