transformers-5.0.0rc3-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -18,24 +18,134 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any

 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
+from torch import Tensor

 from ... import initialization as init
-from ...activations import ACT2CLS
+from ...activations import ACT2CLS
+from ...backbone_utils import load_backbone
 from ...image_transforms import center_to_corners_format, corners_to_center_format
 from ...modeling_outputs import BaseModelOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
-from ...utils import ModelOutput, auto_docstring,
-from ...utils.
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int
+from ...utils.generic import can_return_tuple, check_model_inputs
 from .configuration_d_fine import DFineConfig


+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for outputs of the DFineDecoder. This class adds two attributes to
+    BaseModelOutputWithCrossAttentions, namely:
+    - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
+    - a stacked tensor of intermediate reference points.
+    """
+)
+class DFineDecoderOutput(ModelOutput):
+    r"""
+    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+        Stacked intermediate hidden states (output of each layer of the decoder).
+    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
+        Stacked intermediate logits (logits of each layer of the decoder).
+    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
+        Stacked intermediate reference points (reference points of each layer of the decoder).
+    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+        Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
+    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+        Stacked initial reference points (initial reference points of each layer of the decoder).
+    cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+        sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
+        used to compute the weighted average in the cross-attention heads.
+    """
+
+    last_hidden_state: torch.FloatTensor | None = None
+    intermediate_hidden_states: torch.FloatTensor | None = None
+    intermediate_logits: torch.FloatTensor | None = None
+    intermediate_reference_points: torch.FloatTensor | None = None
+    intermediate_predicted_corners: torch.FloatTensor | None = None
+    initial_reference_points: torch.FloatTensor | None = None
+    hidden_states: tuple[torch.FloatTensor] | None = None
+    attentions: tuple[torch.FloatTensor] | None = None
+    cross_attentions: tuple[torch.FloatTensor] | None = None
+
+
+class DFineMLP(nn.Module):
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"):
+        super().__init__()
+        self.num_layers = num_layers
+        hidden_dims = [hidden_dim] * (num_layers - 1)
+        input_dims = [input_dim] + hidden_dims
+        output_dims = hidden_dims + [output_dim]
+        self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims))
+        self.act = ACT2CLS[act]()
+
+    def forward(self, stat_features: torch.Tensor) -> torch.Tensor:
+        for i, layer in enumerate(self.layers):
+            stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features)
+        return stat_features
+
+
+class DFineGate(nn.Module):
+    def __init__(self, d_model: int):
+        super().__init__()
+        self.gate = nn.Linear(2 * d_model, 2 * d_model)
+        self.norm = nn.LayerNorm(d_model)
+
+    def forward(self, second_residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
+        gate_input = torch.cat([second_residual, hidden_states], dim=-1)
+        gates = torch.sigmoid(self.gate(gate_input))
+        gate1, gate2 = gates.chunk(2, dim=-1)
+        hidden_states = self.norm(gate1 * second_residual + gate2 * hidden_states)
+        return hidden_states
+
+
+class DFineFrozenBatchNorm2d(nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+    Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
+    torchvision.models.resnet[18,34,50,101] produce nans.
+    """
+
+    def __init__(self, n):
+        super().__init__()
+        self.register_buffer("weight", torch.ones(n))
+        self.register_buffer("bias", torch.zeros(n))
+        self.register_buffer("running_mean", torch.zeros(n))
+        self.register_buffer("running_var", torch.ones(n))
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        num_batches_tracked_key = prefix + "num_batches_tracked"
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def forward(self, x):
+        # move reshapes to the beginning
+        # to make it user-friendly
+        weight = self.weight.reshape(1, -1, 1, 1)
+        bias = self.bias.reshape(1, -1, 1, 1)
+        running_var = self.running_var.reshape(1, -1, 1, 1)
+        running_mean = self.running_mean.reshape(1, -1, 1, 1)
+        epsilon = 1e-5
+        scale = weight * (running_var + epsilon).rsqrt()
+        bias = bias - running_mean * scale
+        return x * scale + bias
+
+
 def multi_scale_deformable_attention_v2(
     value: Tensor,
     value_spatial_shapes: Tensor,
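Editor's note: the hunk above adds `DFineGate`, which fuses the cross-attention output with its residual through two sigmoid gates followed by a LayerNorm. A minimal standalone sketch of that fusion (the toy `d_model` and random tensors are purely illustrative, not values from the released config):

```python
import torch
from torch import nn

# Two sigmoid gates weight the residual branch and the cross-attention branch
# before a LayerNorm, mirroring the DFineGate fusion added in this diff.
d_model = 8  # illustrative size only
gate = nn.Linear(2 * d_model, 2 * d_model)
norm = nn.LayerNorm(d_model)

residual = torch.randn(2, 5, d_model)        # (batch, queries, d_model)
cross_attn_out = torch.randn(2, 5, d_model)

gates = torch.sigmoid(gate(torch.cat([residual, cross_attn_out], dim=-1)))
gate1, gate2 = gates.chunk(2, dim=-1)
fused = norm(gate1 * residual + gate2 * cross_attn_out)
print(fused.shape)  # torch.Size([2, 5, 8])
```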
@@ -147,14 +257,15 @@ class DFineMultiscaleDeformableAttention(nn.Module):
         encoder_hidden_states=None,
         spatial_shapes=None,
         spatial_shapes_list=None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         batch_size, num_queries, _ = hidden_states.shape
         batch_size, sequence_length, _ = encoder_hidden_states.shape

-
-
-
-
+        torch_compilable_check(
+            (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == sequence_length,
+            "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
+        )

         # Reshape for multi-head attention
         value = encoder_hidden_states.reshape(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
@@ -201,162 +312,464 @@ class DFineMultiscaleDeformableAttention(nn.Module):
         return output, attention_weights


-class
-    def __init__(
+class DFineConvNormLayer(nn.Module):
+    def __init__(
+        self,
+        config: DFineConfig,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int,
+        groups: int = 1,
+        padding: int | None = None,
+        activation: str | None = None,
+    ):
         super().__init__()
-        self.
-
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            groups=groups,
+            padding=(kernel_size - 1) // 2 if padding is None else padding,
+            bias=False,
+        )
+        self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
+        self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()

-    def forward(self,
-
-
-
-
-        return hidden_states
+    def forward(self, hidden_state):
+        hidden_state = self.conv(hidden_state)
+        hidden_state = self.norm(hidden_state)
+        hidden_state = self.activation(hidden_state)
+        return hidden_state


-class
+class DFineRepVggBlock(nn.Module):
+    """
+    RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
     """
-    Multi-headed attention from 'Attention Is All You Need' paper.

-
+    def __init__(self, config: DFineConfig, in_channels: int, out_channels: int):
+        super().__init__()
+
+        activation = config.activation_function
+        hidden_channels = in_channels
+        self.conv1 = DFineConvNormLayer(config, hidden_channels, out_channels, 3, 1, padding=1)
+        self.conv2 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, padding=0)
+        self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
+
+    def forward(self, x):
+        y = self.conv1(x) + self.conv2(x)
+        return self.activation(y)
+
+
+class DFineCSPRepLayer(nn.Module):
+    """
+    Cross Stage Partial (CSP) network layer with RepVGG blocks.
     """

     def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        dropout: float = 0.0,
-        bias: bool = True,
+        self, config: DFineConfig, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0
     ):
         super().__init__()
-
-        self.num_heads = num_heads
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        if self.head_dim * num_heads != self.embed_dim:
-            raise ValueError(
-                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {num_heads})."
-            )
-        self.scaling = self.head_dim**-0.5
+        activation = config.activation_function

-
-        self.
-        self.
-        self.
+        hidden_channels = int(out_channels * expansion)
+        self.conv1 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
+        self.conv2 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
+        self.bottlenecks = nn.ModuleList(
+            [DFineRepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)]
+        )
+        if hidden_channels != out_channels:
+            self.conv3 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
+        else:
+            self.conv3 = nn.Identity()

-    def
-
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        hidden_state_1 = self.conv1(hidden_state)
+        for bottleneck in self.bottlenecks:
+            hidden_state_1 = bottleneck(hidden_state_1)
+        hidden_state_2 = self.conv2(hidden_state)
+        hidden_state_3 = self.conv3(hidden_state_1 + hidden_state_2)
+        return hidden_state_3

-    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Tensor | None):
-        return tensor if position_embeddings is None else tensor + position_embeddings

-
-
-
-
-
-
-
-
+class DFineRepNCSPELAN4(nn.Module):
+    def __init__(self, config: DFineConfig, act: str = "silu", numb_blocks: int = 3):
+        super().__init__()
+        conv1_dim = config.encoder_hidden_dim * 2
+        conv2_dim = config.encoder_hidden_dim
+        conv3_dim = config.encoder_hidden_dim * 2
+        conv4_dim = round(config.hidden_expansion * config.encoder_hidden_dim // 2)
+        self.conv_dim = conv3_dim // 2
+        self.conv1 = DFineConvNormLayer(config, conv1_dim, conv3_dim, 1, 1, activation=act)
+        self.csp_rep1 = DFineCSPRepLayer(config, conv3_dim // 2, conv4_dim, num_blocks=numb_blocks)
+        self.conv2 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
+        self.csp_rep2 = DFineCSPRepLayer(config, conv4_dim, conv4_dim, num_blocks=numb_blocks)
+        self.conv3 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
+        self.conv4 = DFineConvNormLayer(config, conv3_dim + (2 * conv4_dim), conv2_dim, 1, 1, activation=act)

-
-        #
-
-        hidden_states_original = hidden_states
-        hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
+        # Split initial features into two branches after first convolution
+        split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1))

-        #
-
-
-
+        # Process branches sequentially
+        branch1 = self.csp_rep1(split_features[-1])
+        branch1 = self.conv2(branch1)
+        branch2 = self.csp_rep2(branch1)
+        branch2 = self.conv3(branch2)

-
-
-
-
+        split_features.extend([branch1, branch2])
+        merged_features = torch.cat(split_features, 1)
+        merged_features = self.conv4(merged_features)
+        return merged_features

-        source_len = key_states.size(1)

-
+class DFineSCDown(nn.Module):
+    def __init__(self, config: DFineConfig, kernel_size: int, stride: int):
+        super().__init__()
+        self.conv1 = DFineConvNormLayer(config, config.encoder_hidden_dim, config.encoder_hidden_dim, 1, 1)
+        self.conv2 = DFineConvNormLayer(
+            config,
+            config.encoder_hidden_dim,
+            config.encoder_hidden_dim,
+            kernel_size,
+            stride,
+            config.encoder_hidden_dim,
+        )

-
-
-
-
-        )
+    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
+        input_features = self.conv1(input_features)
+        input_features = self.conv2(input_features)
+        return input_features

-        # expand attention_mask
-        if attention_mask is not None:
-            # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-            attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size())

-
-
-
-
-
-
-
-
-
-
-
-
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
-            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
-        else:
-            attn_weights_reshaped = None
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor | None,
+    scaling: float | None = None,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5

-
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

-
-
-                f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

-
-
-        attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()

-
+    return attn_output, attn_weights

-        return attn_output, attn_weights_reshaped

+class DFineSelfAttention(nn.Module):
+    """
+    Multi-headed self-attention from 'Attention Is All You Need' paper.

-
-
-
-
-        self
-
-
+    In D_FINE, position embeddings are added to both queries and keys (but not values) in self-attention.
+    """
+
+    def __init__(
+        self,
+        config: DFineConfig,
+        hidden_size: int,
+        num_attention_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.config = config
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+        self.is_causal = False
+
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Position embeddings are added to both queries and keys (but not values).
+        """
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
+
+        query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class DFineEncoderLayer(nn.Module):
+    def __init__(self, config: DFineConfig):
+        super().__init__()
+        self.normalize_before = config.normalize_before
+        self.hidden_size = config.encoder_hidden_dim
+
+        # self-attention
+        self.self_attn = DFineSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            dropout=config.dropout,
+        )
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = config.dropout
+        self.mlp = DFineMLP(
+            self.hidden_size, config.encoder_ffn_dim, self.hidden_size, 2, config.encoder_activation_function
+        )
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        spatial_position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+                values.
+            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
+                Spatial position embeddings (2D positional encodings of image locations), to be added to both
+                the queries and keys in self-attention (but not to values).
+        """
+        residual = hidden_states
+        if self.normalize_before:
+            hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_embeddings=spatial_position_embeddings,
+            **kwargs,
+        )
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        if not self.normalize_before:
+            hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        if self.normalize_before:
+            hidden_states = self.final_layer_norm(hidden_states)
+        residual = hidden_states
+
+        hidden_states = self.mlp(hidden_states)
+
+        hidden_states = residual + hidden_states
+        if not self.normalize_before:
+            hidden_states = self.final_layer_norm(hidden_states)
+
+        if self.training:
+            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        return hidden_states
+
+
+class DFineSinePositionEmbedding(nn.Module):
+    """
+    2D sinusoidal position embedding used in RT-DETR hybrid encoder.
+    """
+
+    def __init__(self, embed_dim: int = 256, temperature: int = 10000):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.temperature = temperature
+
+    @compile_compatible_method_lru_cache(maxsize=32)
+    def forward(
+        self,
+        width: int,
+        height: int,
+        device: torch.device | str,
+        dtype: torch.dtype,
+    ) -> torch.Tensor:
+        """
+        Generate 2D sinusoidal position embeddings.
+
+        Returns:
+            Position embeddings of shape (1, height*width, embed_dim)
+        """
+        grid_w = torch.arange(torch_int(width), device=device).to(dtype)
+        grid_h = torch.arange(torch_int(height), device=device).to(dtype)
+        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
+        if self.embed_dim % 4 != 0:
+            raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
+        pos_dim = self.embed_dim // 4
+        omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
+        omega = 1.0 / (self.temperature**omega)
+
+        out_w = grid_w.flatten()[..., None] @ omega[None]
+        out_h = grid_h.flatten()[..., None] @ omega[None]
+
+        return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
+
+
+class DFineAIFILayer(nn.Module):
+    """
+    AIFI (Attention-based Intra-scale Feature Interaction) layer used in RT-DETR hybrid encoder.
+    """
+
+    def __init__(self, config: DFineConfig):
+        super().__init__()
+        self.config = config
+        self.encoder_hidden_dim = config.encoder_hidden_dim
+        self.eval_size = config.eval_size
+
+        self.position_embedding = DFineSinePositionEmbedding(
+            embed_dim=self.encoder_hidden_dim,
+            temperature=config.positional_encoding_temperature,
+        )
+        self.layers = nn.ModuleList([DFineEncoderLayer(config) for _ in range(config.encoder_layers)])
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                Feature map to process.
+        """
+        batch_size = hidden_states.shape[0]
+        height, width = hidden_states.shape[2:]
+
+        hidden_states = hidden_states.flatten(2).permute(0, 2, 1)
+
+        if self.training or self.eval_size is None:
+            pos_embed = self.position_embedding(
+                width=width,
+                height=height,
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+        else:
+            pos_embed = None
+
+        for layer in self.layers:
+            hidden_states = layer(
+                hidden_states,
+                attention_mask=None,
+                spatial_position_embeddings=pos_embed,
+                **kwargs,
+            )
+
+        hidden_states = (
+            hidden_states.permute(0, 2, 1).reshape(batch_size, self.encoder_hidden_dim, height, width).contiguous()
+        )
+
+        return hidden_states
+
+
+class DFineIntegral(nn.Module):
+    """
+    A static layer that calculates integral results from a distribution.
+
+    This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`,
+    where Pr(n) is the softmax probability vector representing the discrete
+    distribution, and W(n) is the non-uniform Weighting Function.
+
+    Args:
+        max_num_bins (int): Max number of the discrete bins. Default is 32.
+            It can be adjusted based on the dataset or task requirements.
+    """
+
+    def __init__(self, config: DFineConfig):
+        super().__init__()
+        self.max_num_bins = config.max_num_bins
+
+    def forward(self, pred_corners: torch.Tensor, project: torch.Tensor) -> torch.Tensor:
+        batch_size, num_queries, _ = pred_corners.shape
+        pred_corners = F.softmax(pred_corners.reshape(-1, self.max_num_bins + 1), dim=1)
+        pred_corners = F.linear(pred_corners, project.to(pred_corners.device)).reshape(-1, 4)
+        pred_corners = pred_corners.reshape(batch_size, num_queries, -1)
+        return pred_corners
+
+
+class DFineLQE(nn.Module):
+    def __init__(self, config: DFineConfig):
+        super().__init__()
+        self.top_prob_values = config.top_prob_values
+        self.max_num_bins = config.max_num_bins
+        self.reg_conf = DFineMLP(4 * (self.top_prob_values + 1), config.lqe_hidden_dim, 1, config.lqe_layers)
+
+    def forward(self, scores: torch.Tensor, pred_corners: torch.Tensor) -> torch.Tensor:
+        batch_size, length, _ = pred_corners.size()
+        prob = F.softmax(pred_corners.reshape(batch_size, length, 4, self.max_num_bins + 1), dim=-1)
+        prob_topk, _ = prob.topk(self.top_prob_values, dim=-1)
+        stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)
+        quality_score = self.reg_conf(stat.reshape(batch_size, length, -1))
+        scores = scores + quality_score
+        return scores
+
+
+class DFineDecoderLayer(nn.Module):
+    def __init__(self, config: DFineConfig):
+        super().__init__()
+        self.hidden_size = config.d_model
+
+        # self-attention
+        self.self_attn = DFineSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.decoder_attention_heads,
             dropout=config.attention_dropout,
         )
         self.dropout = config.dropout
-        self.activation_fn = ACT2FN[config.decoder_activation_function]
-        self.activation_dropout = config.activation_dropout

-        self.self_attn_layer_norm = nn.LayerNorm(
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)

         # override the encoder attention module with d-fine version
         self.encoder_attn = DFineMultiscaleDeformableAttention(config=config)
-
-
-
-        self.final_layer_norm = nn.LayerNorm(
+        self.mlp = DFineMLP(
+            self.hidden_size, config.decoder_ffn_dim, self.hidden_size, 2, config.decoder_activation_function
+        )
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
         # gate
         self.gateway = DFineGate(config.d_model)

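Editor's note: the `DFineIntegral` docstring in the hunk above describes the distribution-to-location step as `sum{Pr(n) * W(n)}`. A minimal sketch of that expectation, using an evenly spaced weighting vector purely for illustration (the model's actual `project` weights are non-uniform and come from the config):

```python
import torch
import torch.nn.functional as F

max_num_bins = 32
batch_size, num_queries = 2, 5

# Raw corner logits: one (max_num_bins + 1)-way distribution per box edge.
pred_corners = torch.randn(batch_size, num_queries, 4 * (max_num_bins + 1))

# Illustrative weighting function W(n); the real projection is non-uniform.
project = torch.linspace(-0.5, 0.5, max_num_bins + 1)

prob = F.softmax(pred_corners.reshape(-1, max_num_bins + 1), dim=1)  # Pr(n)
offsets = (prob * project).sum(dim=1).reshape(batch_size, num_queries, 4)
print(offsets.shape)  # torch.Size([2, 5, 4])
```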
@@ -369,14 +782,15 @@ class DFineDecoderLayer(nn.Module):
         spatial_shapes_list=None,
         encoder_hidden_states: torch.Tensor | None = None,
         encoder_attention_mask: torch.Tensor | None = None,
-
-    ) ->
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`):
-                Input to the layer of shape `(
-
-                Position embeddings
+                Input to the layer of shape `(batch, seq_len, hidden_size)`.
+            object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
+                Position embeddings for the object query slots. These are added to both queries and keys
+                in the self-attention layer (not values).
             reference_points (`torch.FloatTensor`, *optional*):
                 Reference points.
             spatial_shapes (`torch.LongTensor`, *optional*):
@@ -384,55 +798,65 @@ class DFineDecoderLayer(nn.Module):
             level_start_index (`torch.LongTensor`, *optional*):
                 Level start index.
             encoder_hidden_states (`torch.FloatTensor`):
-                cross attention input to the layer of shape `(
+                cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
             encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
         """
+        residual = hidden_states
+
         # Self Attention
-
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=encoder_attention_mask,
             position_embeddings=position_embeddings,
-
+            **kwargs,
         )

-
-        hidden_states =
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
+
         residual = hidden_states

         # Cross-Attention
-        cross_attn_weights = None
         hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings
-
+        hidden_states, _ = self.encoder_attn(
             hidden_states=hidden_states,
             encoder_hidden_states=encoder_hidden_states,
             reference_points=reference_points,
             spatial_shapes=spatial_shapes,
             spatial_shapes_list=spatial_shapes_list,
         )
-
-
-        hidden_states = self.gateway(residual, hidden_states_2)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = self.gateway(residual, hidden_states)

         # Fully Connected
-
-
-
-        hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
-        hidden_states = hidden_states + hidden_states_2
+        residual = hidden_states
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504))

-
+        return hidden_states
+
+
+class DFineMLPPredictionHead(nn.Module):
+    """
+    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
+    height and width of a bounding box w.r.t. an image.
+
+    """

-
-
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

-
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x


 @auto_docstring
@@ -442,6 +866,10 @@ class DFinePreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     input_modalities = ("image",)
     _no_split_modules = [r"DFineHybridEncoder", r"DFineDecoderLayer"]
+    _supports_sdpa = True
+    _supports_flash_attn = True
+    _supports_attention_backend = True
+    _supports_flex_attn = True

     @torch.no_grad()
     def _init_weights(self, module):
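Editor's note: the flags added above advertise that the D-FINE attention modules can be routed through the shared attention backends. As a usage sketch (the checkpoint path below is a placeholder, not a specific released model), the backend would typically be selected at load time:

```python
from transformers import AutoModel

# "eager", "sdpa" and "flash_attention_2" are the usual backend names.
model = AutoModel.from_pretrained(
    "path/to/d-fine-checkpoint",  # placeholder checkpoint id
    attn_implementation="sdpa",
)
```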
@@ -519,67 +947,102 @@ class DFinePreTrainedModel(PreTrainedModel):
             init.xavier_uniform_(module.denoising_class_embed.weight)


-class
+class DFineHybridEncoder(DFinePreTrainedModel):
     """
-
-
-
-    where Pr(n) is the softmax probability vector representing the discrete
-    distribution, and W(n) is the non-uniform Weighting Function.
+    Hybrid encoder consisting of AIFI (Attention-based Intra-scale Feature Interaction) layers,
+    a top-down Feature Pyramid Network (FPN) and a bottom-up Path Aggregation Network (PAN).
+    More details on the paper: https://huggingface.co/papers/2304.08069

     Args:
-
-            It can be adjusted based on the dataset or task requirements.
+        config: DFineConfig
     """

+    _can_record_outputs = {
+        "hidden_states": DFineAIFILayer,
+        "attentions": DFineSelfAttention,
+    }
+
     def __init__(self, config: DFineConfig):
-        super().__init__()
-        self.
+        super().__init__(config)
+        self.config = config
+        self.in_channels = config.encoder_in_channels
+        self.num_fpn_stages = len(self.in_channels) - 1
+        self.feat_strides = config.feat_strides
+        self.encoder_hidden_dim = config.encoder_hidden_dim
+        self.encode_proj_layers = config.encode_proj_layers
+        self.positional_encoding_temperature = config.positional_encoding_temperature
+        self.eval_size = config.eval_size
+        self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
+        self.out_strides = self.feat_strides

-
-
-        pred_corners = F.softmax(pred_corners.reshape(-1, self.max_num_bins + 1), dim=1)
-        pred_corners = F.linear(pred_corners, project.to(pred_corners.device)).reshape(-1, 4)
-        pred_corners = pred_corners.reshape(batch_size, num_queries, -1)
-        return pred_corners
+        # AIFI (Attention-based Intra-scale Feature Interaction) layers
+        self.aifi = nn.ModuleList([DFineAIFILayer(config) for _ in range(len(self.encode_proj_layers))])

+        # top-down fpn
+        self.lateral_convs = nn.ModuleList()
+        self.fpn_blocks = nn.ModuleList()
+        for _ in range(len(self.in_channels) - 1, 0, -1):
+            lateral_layer = DFineConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1)
+            self.lateral_convs.append(lateral_layer)
+            num_blocks = round(3 * config.depth_mult)
+            fpn_layer = DFineRepNCSPELAN4(config, numb_blocks=num_blocks)
+            self.fpn_blocks.append(fpn_layer)

-
-
-
-
-
-
-
-    """
-)
-class DFineDecoderOutput(ModelOutput):
-    r"""
-    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
-        Stacked intermediate hidden states (output of each layer of the decoder).
-    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
-        Stacked intermediate logits (logits of each layer of the decoder).
-    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
-        Stacked intermediate reference points (reference points of each layer of the decoder).
-    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
-        Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
-    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
-        Stacked initial reference points (initial reference points of each layer of the decoder).
-    cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
-        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
-        sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
-        used to compute the weighted average in the cross-attention heads.
-    """
+        # bottom-up pan
+        self.downsample_convs = nn.ModuleList()
+        self.pan_blocks = nn.ModuleList()
+        for _ in range(len(self.in_channels) - 1):
+            self.downsample_convs.append(DFineSCDown(config, 3, 2))
+            num_blocks = round(3 * config.depth_mult)
+            self.pan_blocks.append(DFineRepNCSPELAN4(config, numb_blocks=num_blocks))

-
-
-
-
-
-
-
-
-
+        self.post_init()
+
+    @check_model_inputs(tie_last_hidden_states=False)
+    def forward(
+        self,
+        inputs_embeds=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
+        """
+        feature_maps = inputs_embeds
+
+        # AIFI: Apply transformer encoder to specified feature levels
+        if self.config.encoder_layers > 0:
+            for i, enc_ind in enumerate(self.encode_proj_layers):
+                feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs)
+
+        # top-down FPN
+        fpn_feature_maps = [feature_maps[-1]]
+        for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
+            backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1]
+            top_fpn_feature_map = fpn_feature_maps[-1]
+            # apply lateral block
+            top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
+            fpn_feature_maps[-1] = top_fpn_feature_map
+            # apply fpn block
+            top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
+            fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
+            new_fpn_feature_map = fpn_block(fused_feature_map)
+            fpn_feature_maps.append(new_fpn_feature_map)
+
+        fpn_feature_maps.reverse()
+
+        # bottom-up PAN
+        pan_feature_maps = [fpn_feature_maps[0]]
+        for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
+            top_pan_feature_map = pan_feature_maps[-1]
+            fpn_feature_map = fpn_feature_maps[idx + 1]
+            downsampled_feature_map = downsample_conv(top_pan_feature_map)
+            fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
+            new_pan_feature_map = pan_block(fused_feature_map)
+            pan_feature_maps.append(new_pan_feature_map)
+
+        return BaseModelOutput(last_hidden_state=pan_feature_maps)


 def inverse_sigmoid(x, eps=1e-5):
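Editor's note: the `DFineHybridEncoder.forward` added above traverses the multi-scale features top-down (FPN) and then bottom-up (PAN), fusing neighbouring levels by upsampling or downsampling followed by concatenation. A schematic sketch of that traversal, with a slicing "fuse" and a max-pool standing in for the real conv blocks (all shapes are illustrative):

```python
import torch
import torch.nn.functional as F


def fuse(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Stand-in for the RepNCSPELAN-style fusion block: concat then keep 16 channels.
    return torch.cat([a, b], dim=1)[:, :16]


# Three dummy feature maps at decreasing resolution (finest to coarsest).
feature_maps = [torch.randn(1, 16, s, s) for s in (64, 32, 16)]

# Top-down FPN: start from the coarsest map, upsample, fuse with the next finer one.
fpn = [feature_maps[-1]]
for finer in reversed(feature_maps[:-1]):
    upsampled = F.interpolate(fpn[-1], scale_factor=2.0, mode="nearest")
    fpn.append(fuse(upsampled, finer))
fpn.reverse()  # finest-to-coarsest, as in the model

# Bottom-up PAN: downsample the running map and fuse with the next coarser FPN level.
pan = [fpn[0]]
for coarser in fpn[1:]:
    downsampled = F.max_pool2d(pan[-1], kernel_size=2)  # stand-in for DFineSCDown
    pan.append(fuse(downsampled, coarser))

print([p.shape for p in pan])
```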
@@ -647,6 +1110,12 @@ class DFineDecoder(DFinePreTrainedModel):
     to improve bounding box accuracy and robustness.
     """

+    _can_record_outputs = {
+        "hidden_states": DFineDecoderLayer,
+        "attentions": DFineSelfAttention,
+        "cross_attentions": DFineMultiscaleDeformableAttention,
+    }
+
     def __init__(self, config: DFineConfig):
         super().__init__(config)
         self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
@@ -656,7 +1125,7 @@ class DFineDecoder(DFinePreTrainedModel):
             [DFineDecoderLayer(config) for _ in range(config.decoder_layers)]
             + [DFineDecoderLayer(config) for _ in range(config.decoder_layers - self.eval_idx - 1)]
         )
-        self.query_pos_head = DFineMLPPredictionHead(
+        self.query_pos_head = DFineMLPPredictionHead(4, 2 * config.d_model, config.d_model, num_layers=2)

         # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
         self.bbox_embed = None
@@ -674,6 +1143,7 @@ class DFineDecoder(DFinePreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    @check_model_inputs()
     def forward(
         self,
         encoder_hidden_states: torch.Tensor,
@@ -682,12 +1152,9 @@ class DFineDecoder(DFinePreTrainedModel):
         spatial_shapes,
         level_start_index=None,
         spatial_shapes_list=None,
-        output_hidden_states=None,
         encoder_attention_mask=None,
         memory_mask=None,
-
-        return_dict=None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> DFineDecoderOutput:
         r"""
         Args:
@@ -701,39 +1168,17 @@ class DFineDecoder(DFinePreTrainedModel):
                 in `[0, 1]`:
                 - 1 for pixels that are real (i.e. **not masked**),
                 - 0 for pixels that are padding (i.e. **masked**).
-            position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
-                Position embeddings that are added to the queries and keys in each self-attention layer.
             reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*):
                 Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
             spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
                 Spatial shapes of the feature maps.
             level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
                 Indexes for the start of each feature level. In range `[0, sequence_length]`.
-            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
-                Ratio of valid area in each feature level.
-
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         if inputs_embeds is not None:
             hidden_states = inputs_embeds

         # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
         intermediate = ()
         intermediate_reference_points = ()
         intermediate_logits = ()
@@ -749,25 +1194,22 @@ class DFineDecoder(DFinePreTrainedModel):
  ref_points_input = ref_points_detach.unsqueeze(2)
  query_pos_embed = self.query_pos_head(ref_points_detach).clamp(min=-10, max=10)

-
-
-
- output = decoder_layer(
- hidden_states=hidden_states,
+ hidden_states = decoder_layer(
+ hidden_states,
  position_embeddings=query_pos_embed,
  reference_points=ref_points_input,
  spatial_shapes=spatial_shapes,
  spatial_shapes_list=spatial_shapes_list,
  encoder_hidden_states=encoder_hidden_states,
  encoder_attention_mask=encoder_attention_mask,
-
+ **kwargs,
  )

- hidden_states = output[0]
-
  if i == 0:
  # Initial bounding box predictions with inverse sigmoid refinement
- new_reference_points = F.sigmoid(
+ new_reference_points = F.sigmoid(
+ self.pre_bbox_head(hidden_states) + inverse_sigmoid(ref_points_detach)
+ )
  ref_points_initial = new_reference_points.detach()

  # Refine bounding box corners using FDR, integrating previous layer's corrections
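For reference, the refinement kept in this hunk adds the box head output to the reference points in logit space, `F.sigmoid(self.pre_bbox_head(hidden_states) + inverse_sigmoid(ref_points_detach))`, which keeps the refined coordinates inside (0, 1). A minimal numeric sketch of that arithmetic (the simple `eps` clamp below approximates, but is not necessarily identical to, the library's `inverse_sigmoid` helper):

import torch


def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Map normalized coordinates back to logits, clamped away from 0 and 1 for stability.
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))


reference_points = torch.tensor([0.25, 0.50, 0.75, 0.90])  # normalized box coordinates
predicted_offset = torch.tensor([0.40, 0.00, -0.40, 0.10])  # head output, a logit-space delta

refined = torch.sigmoid(predicted_offset + inverse_sigmoid(reference_points))
print(refined)  # every entry stays strictly between 0 and 1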
@@ -796,12 +1238,6 @@ class DFineDecoder(DFinePreTrainedModel):
  initial_reference_points += (ref_points_initial,)
  intermediate_predicted_corners += (pred_corners,)

- if output_attentions:
- all_self_attns += (output[1],)
-
- if encoder_hidden_states is not None:
- all_cross_attentions += (output[2],)
-
  # Keep batch_size as first dimension
  intermediate = torch.stack(intermediate)
  if self.class_embed is not None and self.bbox_embed is not None:
@@ -810,27 +1246,6 @@ class DFineDecoder(DFinePreTrainedModel):
  initial_reference_points = torch.stack(initial_reference_points, dim=1)
  intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)

- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if not return_dict:
- return tuple(
- v
- for v in [
- hidden_states,
- intermediate,
- intermediate_logits,
- intermediate_reference_points,
- intermediate_predicted_corners,
- initial_reference_points,
- all_hidden_states,
- all_self_attns,
- all_cross_attentions,
- ]
- if v is not None
- )
-
  return DFineDecoderOutput(
  last_hidden_state=hidden_states,
  intermediate_hidden_states=intermediate,
@@ -838,51 +1253,9 @@ class DFineDecoder(DFinePreTrainedModel):
  intermediate_reference_points=intermediate_reference_points,
  intermediate_predicted_corners=intermediate_predicted_corners,
  initial_reference_points=initial_reference_points,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- cross_attentions=all_cross_attentions,
  )


- class DFineFrozenBatchNorm2d(nn.Module):
- """
- BatchNorm2d where the batch statistics and the affine parameters are fixed.
-
- Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
- torchvision.models.resnet[18,34,50,101] produce nans.
- """
-
- def __init__(self, n):
- super().__init__()
- self.register_buffer("weight", torch.ones(n))
- self.register_buffer("bias", torch.zeros(n))
- self.register_buffer("running_mean", torch.zeros(n))
- self.register_buffer("running_var", torch.ones(n))
-
- def _load_from_state_dict(
- self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
- ):
- num_batches_tracked_key = prefix + "num_batches_tracked"
- if num_batches_tracked_key in state_dict:
- del state_dict[num_batches_tracked_key]
-
- super()._load_from_state_dict(
- state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
- )
-
- def forward(self, x):
- # move reshapes to the beginning
- # to make it user-friendly
- weight = self.weight.reshape(1, -1, 1, 1)
- bias = self.bias.reshape(1, -1, 1, 1)
- running_var = self.running_var.reshape(1, -1, 1, 1)
- running_mean = self.running_mean.reshape(1, -1, 1, 1)
- epsilon = 1e-5
- scale = weight * (running_var + epsilon).rsqrt()
- bias = bias - running_mean * scale
- return x * scale + bias
-
-
  @dataclass
  @auto_docstring(
  custom_intro="""
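The deleted `DFineFrozenBatchNorm2d` simply folded fixed running statistics and affine parameters into one per-channel scale and shift. A small standalone check of that folding against `nn.BatchNorm2d` in eval mode (random values, same `eps`; a sketch, not the module that replaces it):

import torch
from torch import nn

num_channels, eps = 3, 1e-5
x = torch.randn(2, num_channels, 4, 4)

bn = nn.BatchNorm2d(num_channels, eps=eps).eval()
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

# Fold the frozen statistics into a single scale/shift, as the deleted module did.
scale = bn.weight * (bn.running_var + eps).rsqrt()
shift = bn.bias - bn.running_mean * scale
folded = x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)

print(torch.allclose(folded, bn(x), atol=1e-5))  # True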
@@ -1134,8 +1507,8 @@ class DFineModel(DFinePreTrainedModel):
  intermediate_channel_sizes = self.backbone.intermediate_channel_sizes
  num_backbone_outs = len(config.decoder_in_channels)
  encoder_input_proj_list = []
- for
- in_channels = intermediate_channel_sizes[
+ for i in range(num_backbone_outs):
+ in_channels = intermediate_channel_sizes[i]
  encoder_input_proj_list.append(
  nn.Sequential(
  nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
@@ -1161,15 +1534,15 @@ class DFineModel(DFinePreTrainedModel):
  nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
  )
  self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
- self.enc_bbox_head = DFineMLPPredictionHead(config
+ self.enc_bbox_head = DFineMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)

  # init encoder output anchors and valid_mask
  if config.anchor_image_size:
  self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)
  num_backbone_outs = len(config.decoder_in_channels)
  decoder_input_proj_list = []
- for
- in_channels = config.decoder_in_channels[
+ for i in range(num_backbone_outs):
+ in_channels = config.decoder_in_channels[i]
  decoder_input_proj_list.append(
  nn.Sequential(
  nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
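The hunk above also shows `DFineMLPPredictionHead` being constructed without the leading `config` argument it used to take (`config.d_model, config.d_model, 4, num_layers=3`). For orientation, a standalone sketch of such a DETR-style prediction head and the shape it produces (dimensions are illustrative, not taken from a config):

import torch
from torch import nn


class MLPPredictionHead(nn.Module):
    # num_layers linear layers with ReLU between them; the last layer has no activation.
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                x = torch.relu(x)
        return x


head = MLPPredictionHead(256, 256, 4, num_layers=3)
print(head(torch.randn(1, 300, 256)).shape)  # torch.Size([1, 300, 4]) -> one box per query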
@@ -1243,26 +1616,20 @@ class DFineModel(DFinePreTrainedModel):
  return anchors, valid_mask

  @auto_docstring
+ @can_return_tuple
  def forward(
  self,
  pixel_values: torch.FloatTensor,
  pixel_mask: torch.LongTensor | None = None,
  encoder_outputs: torch.FloatTensor | None = None,
  inputs_embeds: torch.FloatTensor | None = None,
- decoder_inputs_embeds: torch.FloatTensor | None = None,
  labels: list[dict] | None = None,
-
- output_hidden_states: bool | None = None,
- return_dict: bool | None = None,
- **kwargs,
+ **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[torch.FloatTensor] | DFineModelOutput:
  r"""
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
  can choose to directly pass a flattened representation of an image.
- decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
- Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
- embedded representation.
  labels (`list[Dict]` of len `(batch_size,)`, *optional*):
  Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
  following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
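With `return_dict` gone from the signature, tuple conversion is delegated to the new `@can_return_tuple` decorator. A simplified analogue of what such a decorator does (the real transformers decorator also consults the model config; this sketch only shows the mechanism, and `returns_tuple_on_request`, `DummyOutput`, `DummyModel` are illustrative names):

import functools


def returns_tuple_on_request(forward):
    # Pop `return_dict` from the call, run forward, and convert the output
    # (anything exposing .to_tuple(), like a ModelOutput) when return_dict=False.
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=True, **kwargs):
        output = forward(self, *args, **kwargs)
        return output if return_dict else output.to_tuple()

    return wrapper


class DummyOutput:
    def __init__(self, last_hidden_state):
        self.last_hidden_state = last_hidden_state

    def to_tuple(self):
        return (self.last_hidden_state,)


class DummyModel:
    @returns_tuple_on_request
    def forward(self):
        return DummyOutput(last_hidden_state=[0.0])


print(DummyModel().forward(return_dict=False))  # ([0.0],)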
@@ -1290,53 +1657,46 @@ class DFineModel(DFinePreTrainedModel):
  >>> list(last_hidden_states.shape)
  [1, 300, 256]
  ```"""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+ if pixel_values is None and inputs_embeds is None:
+ raise ValueError("You have to specify either pixel_values or inputs_embeds")
+
+ if inputs_embeds is None:
+ batch_size, num_channels, height, width = pixel_values.shape
+ device = pixel_values.device
+ if pixel_mask is None:
+ pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+ features = self.backbone(pixel_values, pixel_mask)
+ proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+ else:
+ batch_size = inputs_embeds.shape[0]
+ device = inputs_embeds.device
+ proj_feats = inputs_embeds

  if encoder_outputs is None:
  encoder_outputs = self.encoder(
  proj_feats,
-
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ **kwargs,
  )
- # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
- elif
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+ elif not isinstance(encoder_outputs, BaseModelOutput):
  encoder_outputs = BaseModelOutput(
  last_hidden_state=encoder_outputs[0],
- hidden_states=encoder_outputs[1] if
- attentions=encoder_outputs[2]
- if len(encoder_outputs) > 2
- else encoder_outputs[1]
- if output_attentions
- else None,
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  )

  # Equivalent to def _get_encoder_input
  # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/DFine_pytorch/src/zoo/DFine/DFine_decoder.py#L412
  sources = []
- for level, source in enumerate(encoder_outputs
+ for level, source in enumerate(encoder_outputs.last_hidden_state):
  sources.append(self.decoder_input_proj[level](source))

  # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
  if self.config.num_feature_levels > len(sources):
  _len_sources = len(sources)
- sources.append(self.decoder_input_proj[_len_sources](encoder_outputs
+ sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state)[-1])
  for i in range(_len_sources + 1, self.config.num_feature_levels):
- sources.append(self.decoder_input_proj[i](encoder_outputs[
+ sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1]))

  # Prepare encoder inputs (by flattening)
  source_flatten = []
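The rewritten `elif not isinstance(encoder_outputs, BaseModelOutput)` branch above now guards the optional tuple entries by length rather than by the removed flags. A small sketch of that wrapping with dummy values (requires `transformers` for `BaseModelOutput`; the single-element tuple is illustrative):

import torch
from transformers.modeling_outputs import BaseModelOutput

# A caller-supplied tuple: (last_hidden_state,) or (last_hidden_state, hidden_states, attentions).
encoder_outputs = (torch.randn(1, 300, 256),)

wrapped = BaseModelOutput(
    last_hidden_state=encoder_outputs[0],
    hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
    attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
print(wrapped.last_hidden_state.shape, wrapped.hidden_states, wrapped.attentions)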
@@ -1428,22 +1788,9 @@ class DFineModel(DFinePreTrainedModel):
  spatial_shapes=spatial_shapes,
  spatial_shapes_list=spatial_shapes_list,
  level_start_index=level_start_index,
-
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ **kwargs,
  )

- if not return_dict:
- enc_outputs = tuple(
- value
- for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits]
- if value is not None
- )
- dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values])
- tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs
-
- return tuple_outputs
-
  return DFineModelOutput(
  last_hidden_state=decoder_outputs.last_hidden_state,
  intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
@@ -1555,10 +1902,10 @@ class DFineForObjectDetection(DFinePreTrainedModel):
  # We can't initialize the model on meta device as some weights are modified during the initialization
  _no_split_modules = None
  _tied_weights_keys = {
- r"bbox_embed.(?![0])\d+": "bbox_embed.0",
- r"class_embed.(?![0])\d+": "class_embed.0",
- "
- "
+ r"bbox_embed.(?![0])\d+": r"bbox_embed.0",
+ r"class_embed.(?![0])\d+": r"^class_embed.0",
+ "class_embed": "model.decoder.class_embed",
+ "bbox_embed": "model.decoder.bbox_embed",
  }

  def __init__(self, config: DFineConfig):
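The `_tied_weights_keys` patterns above tie every `bbox_embed.<n>` / `class_embed.<n>` entry with `n > 0` back to index 0; the `(?![0])` negative lookahead is what excludes index 0 itself. A quick standalone check of that regex behaviour (plain `re`, illustrative keys only):

import re

pattern = re.compile(r"bbox_embed.(?![0])\d+")

for key in ["bbox_embed.0", "bbox_embed.1", "bbox_embed.5"]:
    print(key, bool(pattern.fullmatch(key)))
# bbox_embed.0 False
# bbox_embed.1 True
# bbox_embed.5 True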
@@ -1590,18 +1937,15 @@ class DFineForObjectDetection(DFinePreTrainedModel):
  return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]

  @auto_docstring
+ @can_return_tuple
  def forward(
  self,
  pixel_values: torch.FloatTensor,
  pixel_mask: torch.LongTensor | None = None,
  encoder_outputs: torch.FloatTensor | None = None,
  inputs_embeds: torch.FloatTensor | None = None,
- decoder_inputs_embeds: torch.FloatTensor | None = None,
  labels: list[dict] | None = None,
-
- output_hidden_states: bool | None = None,
- return_dict: bool | None = None,
- **kwargs,
+ **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[torch.FloatTensor] | DFineObjectDetectionOutput:
  r"""
  Example:
@@ -1648,40 +1992,29 @@ class DFineForObjectDetection(DFinePreTrainedModel):
  Detected sofa with confidence 0.918 at location [0.59, 1.88, 640.25, 474.74]
  ```
  """
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
  outputs = self.model(
  pixel_values,
  pixel_mask=pixel_mask,
  encoder_outputs=encoder_outputs,
  inputs_embeds=inputs_embeds,
- decoder_inputs_embeds=decoder_inputs_embeds,
  labels=labels,
-
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ **kwargs,
  )

- denoising_meta_values =
- outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None
- )
+ denoising_meta_values = outputs.denoising_meta_values if self.training else None

- outputs_class = outputs.intermediate_logits
- outputs_coord = outputs.intermediate_reference_points
- predicted_corners = outputs.intermediate_predicted_corners
- initial_reference_points = outputs.initial_reference_points
+ outputs_class = outputs.intermediate_logits
+ outputs_coord = outputs.intermediate_reference_points
+ predicted_corners = outputs.intermediate_predicted_corners
+ initial_reference_points = outputs.initial_reference_points

  logits = outputs_class[:, -1]
  pred_boxes = outputs_coord[:, -1]

  loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
  if labels is not None:
- enc_topk_logits = outputs.enc_topk_logits
- enc_topk_bboxes = outputs.enc_topk_bboxes
+ enc_topk_logits = outputs.enc_topk_logits
+ enc_topk_bboxes = outputs.enc_topk_bboxes
  loss, loss_dict, auxiliary_outputs = self.loss_function(
  logits,
  labels,
@@ -1698,13 +2031,6 @@ class DFineForObjectDetection(DFinePreTrainedModel):
  **kwargs,
  )

- if not return_dict:
- if auxiliary_outputs is not None:
- output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs
- else:
- output = (logits, pred_boxes) + outputs
- return ((loss, loss_dict) + output) if loss is not None else output
-
  return DFineObjectDetectionOutput(
  loss=loss,
  loss_dict=loss_dict,
@@ -1732,470 +2058,4 @@ class DFineForObjectDetection(DFinePreTrainedModel):
  )


- # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
- class DFineMLPPredictionHead(nn.Module):
- """
- Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
- height and width of a bounding box w.r.t. an image.
-
- Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
- Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/DFine_paddle/ppdet/modeling/transformers/utils.py#L453
-
- """
-
- def __init__(self, config, input_dim, d_model, output_dim, num_layers):
- super().__init__()
- self.num_layers = num_layers
- h = [d_model] * (num_layers - 1)
- self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
-
- def forward(self, x):
- for i, layer in enumerate(self.layers):
- x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
- return x
-
-
- class DFineMLP(nn.Module):
- def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"):
- super().__init__()
- self.num_layers = num_layers
- hidden_dims = [hidden_dim] * (num_layers - 1)
- input_dims = [input_dim] + hidden_dims
- output_dims = hidden_dims + [output_dim]
- self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims))
- self.act = ACT2CLS[act]()
-
- def forward(self, stat_features: torch.Tensor) -> torch.Tensor:
- for i, layer in enumerate(self.layers):
- stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features)
- return stat_features
-
-
- class DFineLQE(nn.Module):
- def __init__(self, config: DFineConfig):
- super().__init__()
- self.top_prob_values = config.top_prob_values
- self.max_num_bins = config.max_num_bins
- self.reg_conf = DFineMLP(4 * (self.top_prob_values + 1), config.lqe_hidden_dim, 1, config.lqe_layers)
-
- def forward(self, scores: torch.Tensor, pred_corners: torch.Tensor) -> torch.Tensor:
- batch_size, length, _ = pred_corners.size()
- prob = F.softmax(pred_corners.reshape(batch_size, length, 4, self.max_num_bins + 1), dim=-1)
- prob_topk, _ = prob.topk(self.top_prob_values, dim=-1)
- stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)
- quality_score = self.reg_conf(stat.reshape(batch_size, length, -1))
- scores = scores + quality_score
- return scores
-
-
- class DFineConvNormLayer(nn.Module):
- def __init__(
- self,
- config: DFineConfig,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: int,
- groups: int = 1,
- padding: int | None = None,
- activation: str | None = None,
- ):
- super().__init__()
- self.conv = nn.Conv2d(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- groups=groups,
- padding=(kernel_size - 1) // 2 if padding is None else padding,
- bias=False,
- )
- self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
- self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
-
- def forward(self, hidden_state):
- hidden_state = self.conv(hidden_state)
- hidden_state = self.norm(hidden_state)
- hidden_state = self.activation(hidden_state)
- return hidden_state
-
-
- class DFineRepVggBlock(nn.Module):
- """
- RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
- """
-
- def __init__(self, config: DFineConfig, in_channels: int, out_channels: int):
- super().__init__()
-
- activation = config.activation_function
- hidden_channels = in_channels
- self.conv1 = DFineConvNormLayer(config, hidden_channels, out_channels, 3, 1, padding=1)
- self.conv2 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, padding=0)
- self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
-
- def forward(self, x):
- y = self.conv1(x) + self.conv2(x)
- return self.activation(y)
-
-
- class DFineCSPRepLayer(nn.Module):
- """
- Cross Stage Partial (CSP) network layer with RepVGG blocks.
- """
-
- def __init__(
- self, config: DFineConfig, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0
- ):
- super().__init__()
- activation = config.activation_function
-
- hidden_channels = int(out_channels * expansion)
- self.conv1 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
- self.conv2 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
- self.bottlenecks = nn.ModuleList(
- [DFineRepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)]
- )
- if hidden_channels != out_channels:
- self.conv3 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
- else:
- self.conv3 = nn.Identity()
-
- def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
- hidden_state_1 = self.conv1(hidden_state)
- for bottleneck in self.bottlenecks:
- hidden_state_1 = bottleneck(hidden_state_1)
- hidden_state_2 = self.conv2(hidden_state)
- hidden_state_3 = self.conv3(hidden_state_1 + hidden_state_2)
- return hidden_state_3
-
-
- class DFineRepNCSPELAN4(nn.Module):
- def __init__(self, config: DFineConfig, act: str = "silu", numb_blocks: int = 3):
- super().__init__()
- conv1_dim = config.encoder_hidden_dim * 2
- conv2_dim = config.encoder_hidden_dim
- conv3_dim = config.encoder_hidden_dim * 2
- conv4_dim = round(config.hidden_expansion * config.encoder_hidden_dim // 2)
- self.conv_dim = conv3_dim // 2
- self.conv1 = DFineConvNormLayer(config, conv1_dim, conv3_dim, 1, 1, activation=act)
- self.csp_rep1 = DFineCSPRepLayer(config, conv3_dim // 2, conv4_dim, num_blocks=numb_blocks)
- self.conv2 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
- self.csp_rep2 = DFineCSPRepLayer(config, conv4_dim, conv4_dim, num_blocks=numb_blocks)
- self.conv3 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
- self.conv4 = DFineConvNormLayer(config, conv3_dim + (2 * conv4_dim), conv2_dim, 1, 1, activation=act)
-
- def forward(self, input_features: torch.Tensor) -> torch.Tensor:
- # Split initial features into two branches after first convolution
- split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1))
-
- # Process branches sequentially
- branch1 = self.csp_rep1(split_features[-1])
- branch1 = self.conv2(branch1)
- branch2 = self.csp_rep2(branch1)
- branch2 = self.conv3(branch2)
-
- split_features.extend([branch1, branch2])
- merged_features = torch.cat(split_features, 1)
- merged_features = self.conv4(merged_features)
- return merged_features
-
-
- class DFineSCDown(nn.Module):
- def __init__(self, config: DFineConfig, kernel_size: int, stride: int):
- super().__init__()
- self.conv1 = DFineConvNormLayer(config, config.encoder_hidden_dim, config.encoder_hidden_dim, 1, 1)
- self.conv2 = DFineConvNormLayer(
- config,
- config.encoder_hidden_dim,
- config.encoder_hidden_dim,
- kernel_size,
- stride,
- config.encoder_hidden_dim,
- )
-
- def forward(self, input_features: torch.Tensor) -> torch.Tensor:
- input_features = self.conv1(input_features)
- input_features = self.conv2(input_features)
- return input_features
-
-
- class DFineEncoderLayer(nn.Module):
- def __init__(self, config: DFineConfig):
- super().__init__()
- self.normalize_before = config.normalize_before
-
- # self-attention
- self.self_attn = DFineMultiheadAttention(
- embed_dim=config.encoder_hidden_dim,
- num_heads=config.num_attention_heads,
- dropout=config.dropout,
- )
- self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
- self.dropout = config.dropout
- self.activation_fn = ACT2FN[config.encoder_activation_function]
- self.activation_dropout = config.activation_dropout
- self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim)
- self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim)
- self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- position_embeddings: torch.Tensor | None = None,
- output_attentions: bool = False,
- **kwargs,
- ):
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`): attention mask of size
- `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
- values.
- position_embeddings (`torch.FloatTensor`, *optional*):
- Object queries (also called content embeddings), to be added to the hidden states.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- """
- residual = hidden_states
- if self.normalize_before:
- hidden_states = self.self_attn_layer_norm(hidden_states)
-
- hidden_states, attn_weights = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_embeddings=position_embeddings,
- output_attentions=output_attentions,
- )
-
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + hidden_states
- if not self.normalize_before:
- hidden_states = self.self_attn_layer_norm(hidden_states)
-
- if self.normalize_before:
- hidden_states = self.final_layer_norm(hidden_states)
- residual = hidden_states
-
- hidden_states = self.activation_fn(self.fc1(hidden_states))
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-
- hidden_states = self.fc2(hidden_states)
-
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
- hidden_states = residual + hidden_states
- if not self.normalize_before:
- hidden_states = self.final_layer_norm(hidden_states)
-
- if self.training:
- if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
- clamp_value = torch.finfo(hidden_states.dtype).max - 1000
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (attn_weights,)
-
- return outputs
-
-
- class DFineEncoder(nn.Module):
- def __init__(self, config: DFineConfig):
- super().__init__()
-
- self.layers = nn.ModuleList([DFineEncoderLayer(config) for _ in range(config.encoder_layers)])
-
- def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor:
- hidden_states = src
- for layer in self.layers:
- hidden_states = layer(
- hidden_states,
- attention_mask=src_mask,
- position_embeddings=pos_embed,
- output_attentions=output_attentions,
- )
- return hidden_states
-
-
- class DFineHybridEncoder(nn.Module):
- """
- Decoder consisting of a projection layer, a set of `DFineEncoder`, a top-down Feature Pyramid Network
- (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://huggingface.co/papers/2304.08069
-
- Args:
- config: DFineConfig
- """
-
- def __init__(self, config: DFineConfig):
- super().__init__()
- self.config = config
- self.in_channels = config.encoder_in_channels
- self.num_fpn_stages = len(self.in_channels) - 1
- self.feat_strides = config.feat_strides
- self.encoder_hidden_dim = config.encoder_hidden_dim
- self.encode_proj_layers = config.encode_proj_layers
- self.positional_encoding_temperature = config.positional_encoding_temperature
- self.eval_size = config.eval_size
- self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
- self.out_strides = self.feat_strides
-
- # encoder transformer
- self.encoder = nn.ModuleList([DFineEncoder(config) for _ in range(len(self.encode_proj_layers))])
- # top-down fpn
- self.lateral_convs = nn.ModuleList()
- self.fpn_blocks = nn.ModuleList()
- for _ in range(len(self.in_channels) - 1, 0, -1):
- lateral_layer = DFineConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1)
- self.lateral_convs.append(lateral_layer)
- num_blocks = round(3 * config.depth_mult)
- fpn_layer = DFineRepNCSPELAN4(config, numb_blocks=num_blocks)
- self.fpn_blocks.append(fpn_layer)
-
- # bottom-up pan
- self.downsample_convs = nn.ModuleList()
- self.pan_blocks = nn.ModuleList()
- for _ in range(len(self.in_channels) - 1):
- self.downsample_convs.append(DFineSCDown(config, 3, 2))
- num_blocks = round(3 * config.depth_mult)
- self.pan_blocks.append(DFineRepNCSPELAN4(config, numb_blocks=num_blocks))
-
- @staticmethod
- def build_2d_sincos_position_embedding(
- width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
- ):
- grid_w = torch.arange(torch_int(width), device=device).to(dtype)
- grid_h = torch.arange(torch_int(height), device=device).to(dtype)
- grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
- if embed_dim % 4 != 0:
- raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
- pos_dim = embed_dim // 4
- omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
- omega = 1.0 / (temperature**omega)
-
- out_w = grid_w.flatten()[..., None] @ omega[None]
- out_h = grid_h.flatten()[..., None] @ omega[None]
-
- return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
-
- def forward(
- self,
- inputs_embeds=None,
- attention_mask=None,
- position_embeddings=None,
- spatial_shapes=None,
- level_start_index=None,
- valid_ratios=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
- r"""
- Args:
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
- - 1 for pixel features that are real (i.e. **not masked**),
- - 0 for pixel features that are padding (i.e. **masked**).
- [What are attention masks?](../glossary#attention-mask)
- position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Position embeddings that are added to the queries and keys in each self-attention layer.
- spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
- Spatial shapes of each feature map.
- level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
- Starting index of each feature map.
- valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
- Ratio of valid area in each feature level.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
- """
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- hidden_states = inputs_embeds
-
- encoder_states = () if output_hidden_states else None
- all_attentions = () if output_attentions else None
-
- # encoder
- if self.config.encoder_layers > 0:
- for i, enc_ind in enumerate(self.encode_proj_layers):
- if output_hidden_states:
- encoder_states = encoder_states + (hidden_states[enc_ind],)
- height, width = hidden_states[enc_ind].shape[2:]
- # flatten [batch, channel, height, width] to [batch, height*width, channel]
- src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
- if self.training or self.eval_size is None:
- pos_embed = self.build_2d_sincos_position_embedding(
- width,
- height,
- self.encoder_hidden_dim,
- self.positional_encoding_temperature,
- device=src_flatten.device,
- dtype=src_flatten.dtype,
- )
- else:
- pos_embed = None
-
- layer_outputs = self.encoder[i](
- src_flatten,
- pos_embed=pos_embed,
- output_attentions=output_attentions,
- )
- hidden_states[enc_ind] = (
- layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
- )
-
- if output_attentions:
- all_attentions = all_attentions + (layer_outputs[1],)
-
- if output_hidden_states:
- encoder_states = encoder_states + (hidden_states[enc_ind],)
-
- # top-down FPN
- fpn_feature_maps = [hidden_states[-1]]
- for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
- backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1]
- top_fpn_feature_map = fpn_feature_maps[-1]
- # apply lateral block
- top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
- fpn_feature_maps[-1] = top_fpn_feature_map
- # apply fpn block
- top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
- fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
- new_fpn_feature_map = fpn_block(fused_feature_map)
- fpn_feature_maps.append(new_fpn_feature_map)
-
- fpn_feature_maps.reverse()
-
- # bottom-up PAN
- pan_feature_maps = [fpn_feature_maps[0]]
- for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
- top_pan_feature_map = pan_feature_maps[-1]
- fpn_feature_map = fpn_feature_maps[idx + 1]
- downsampled_feature_map = downsample_conv(top_pan_feature_map)
- fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
- new_pan_feature_map = pan_block(fused_feature_map)
- pan_feature_maps.append(new_pan_feature_map)
-
- if not return_dict:
- return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None)
- return BaseModelOutput(
- last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions
- )
-
-
  __all__ = ["DFineModel", "DFinePreTrainedModel", "DFineForObjectDetection"]