transformers-5.0.0rc3-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
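For readers who want to reproduce a per-file summary like the listing below, here is a minimal sketch using only the Python standard library (`zipfile` and `difflib`). It is not the tool that produced this page, it does not follow renames, and the wheel filenames in the usage note are illustrative.

```python
# Minimal sketch: per-file added/removed line counts between two wheels.
import difflib
import zipfile


def wheel_diff_stats(old_whl: str, new_whl: str) -> dict[str, tuple[int, int]]:
    """Return {path: (lines_added, lines_removed)} for .py files that changed."""
    stats: dict[str, tuple[int, int]] = {}
    with zipfile.ZipFile(old_whl) as old, zipfile.ZipFile(new_whl) as new:
        old_names, new_names = set(old.namelist()), set(new.namelist())
        for name in sorted(old_names | new_names):
            if not name.endswith(".py"):
                continue
            old_lines = old.read(name).decode("utf-8", "replace").splitlines() if name in old_names else []
            new_lines = new.read(name).decode("utf-8", "replace").splitlines() if name in new_names else []
            added = removed = 0
            for line in difflib.unified_diff(old_lines, new_lines, lineterm=""):
                if line.startswith("+") and not line.startswith("+++"):
                    added += 1
                elif line.startswith("-") and not line.startswith("---"):
                    removed += 1
            if added or removed:
                stats[name] = (added, removed)
    return stats


# Example usage (filenames are illustrative):
# for path, (a, r) in wheel_diff_stats(
#     "transformers-5.0.0rc3-py3-none-any.whl",
#     "transformers-5.1.0-py3-none-any.whl",
# ).items():
#     print(f"{path} +{a} -{r}")
```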
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
transformers/models/conditional_detr/modeling_conditional_detr.py

@@ -1,3 +1,9 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/conditional_detr/modular_conditional_detr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_conditional_detr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # Copyright 2022 Microsoft Research Asia and The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,39 +17,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch Conditional DETR model."""
-
 import math
+from collections.abc import Callable
 from dataclasses import dataclass

 import torch
-from torch import
+from torch import nn

 from ... import initialization as init
 from ...activations import ACT2FN
-from ...
+from ...backbone_utils import load_backbone
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
-from ...modeling_utils import PreTrainedModel
-from ...
-from ...
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring
+from ...utils.generic import OutputRecorder, can_return_tuple, check_model_inputs
 from .configuration_conditional_detr import ConditionalDetrConfig


-if is_timm_available():
-    from timm import create_model
-
-
-logger = logging.get_logger(__name__)
-
-
 @dataclass
 @auto_docstring(
     custom_intro="""
-    Base class for outputs of the
-
-
-    decoding losses.
+    Base class for outputs of the CONDITIONAL_DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
+    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
+    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
     """
 )
 class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
@@ -60,16 +60,16 @@ class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
     """

     intermediate_hidden_states: torch.FloatTensor | None = None
+
     reference_points: tuple[torch.FloatTensor] | None = None


 @dataclass
 @auto_docstring(
     custom_intro="""
-    Base class for outputs of the
-
-
-    losses.
+    Base class for outputs of the CONDITIONAL_DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
+    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
+    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
     """
 )
 class ConditionalDetrModelOutput(Seq2SeqModelOutput):
@@ -84,6 +84,7 @@ class ConditionalDetrModelOutput(Seq2SeqModelOutput):
     """

     intermediate_hidden_states: torch.FloatTensor | None = None
+
     reference_points: tuple[torch.FloatTensor] | None = None


@@ -93,7 +94,6 @@ class ConditionalDetrModelOutput(Seq2SeqModelOutput):
     Output type of [`ConditionalDetrForObjectDetection`].
     """
 )
-# Copied from transformers.models.detr.modeling_detr.DetrObjectDetectionOutput with Detr->ConditionalDetr
 class ConditionalDetrObjectDetectionOutput(ModelOutput):
     r"""
     loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
@@ -137,7 +137,6 @@ class ConditionalDetrObjectDetectionOutput(ModelOutput):
     Output type of [`ConditionalDetrForSegmentation`].
     """
 )
-# Copied from transformers.models.detr.modeling_detr.DetrSegmentationOutput with Detr->ConditionalDetr
 class ConditionalDetrSegmentationOutput(ModelOutput):
     r"""
     loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
@@ -182,7 +181,6 @@ class ConditionalDetrSegmentationOutput(ModelOutput):
     encoder_attentions: tuple[torch.FloatTensor] | None = None


-# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->ConditionalDetr
 class ConditionalDetrFrozenBatchNorm2d(nn.Module):
     """
     BatchNorm2d where the batch statistics and the affine parameters are fixed.
@@ -222,7 +220,6 @@ class ConditionalDetrFrozenBatchNorm2d(nn.Module):
         return x * scale + bias


-# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->ConditionalDetr
 def replace_batch_norm(model):
     r"""
     Recursively replace all `torch.nn.BatchNorm2d` with `ConditionalDetrFrozenBatchNorm2d`.
@@ -247,7 +244,6 @@ def replace_batch_norm(model):
             replace_batch_norm(module)


-# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->ConditionalDetr
 class ConditionalDetrConvEncoder(nn.Module):
     """
     Convolutional backbone, using either the AutoBackbone API or one from the timm library.
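The two hunks above keep the frozen-batch-norm machinery: every `torch.nn.BatchNorm2d` in the backbone is swapped for `ConditionalDetrFrozenBatchNorm2d`, whose statistics and affine parameters never update. A minimal, self-contained sketch of the same recursive-replacement pattern (generic class and helper names, not the library implementation):

```python
import torch
from torch import nn


class FrozenBatchNorm2d(nn.Module):
    """BatchNorm2d with fixed (buffer-only) statistics and affine parameters."""

    def __init__(self, num_features: int):
        super().__init__()
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fold the frozen statistics into a per-channel scale and bias.
        scale = self.weight * (self.running_var + 1e-5).rsqrt()
        bias = self.bias - self.running_mean * scale
        return x * scale[None, :, None, None] + bias[None, :, None, None]


def replace_batch_norm(model: nn.Module) -> None:
    """Recursively swap nn.BatchNorm2d children for FrozenBatchNorm2d, copying their stats."""
    for name, module in model.named_children():
        if isinstance(module, nn.BatchNorm2d):
            frozen = FrozenBatchNorm2d(module.num_features)
            frozen.weight.copy_(module.weight.detach())
            frozen.bias.copy_(module.bias.detach())
            frozen.running_mean.copy_(module.running_mean)
            frozen.running_var.copy_(module.running_var)
            setattr(model, name, frozen)
        elif len(list(module.children())) > 0:
            replace_batch_norm(module)


net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
replace_batch_norm(net)
print(net)  # the BatchNorm2d child is now a FrozenBatchNorm2d
```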
@@ -261,47 +257,25 @@ class ConditionalDetrConvEncoder(nn.Module):

         self.config = config

-
-
-            # We default to values which were previously hard-coded. This enables configurability from the config
-            # using backbone arguments, while keeping the default behavior the same.
-            requires_backends(self, ["timm"])
-            kwargs = getattr(config, "backbone_kwargs", {})
-            kwargs = {} if kwargs is None else kwargs.copy()
-            out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
-            num_channels = kwargs.pop("in_chans", config.num_channels)
-            if config.dilation:
-                kwargs["output_stride"] = kwargs.get("output_stride", 16)
-            backbone = create_model(
-                config.backbone,
-                pretrained=config.use_pretrained_backbone,
-                features_only=True,
-                out_indices=out_indices,
-                in_chans=num_channels,
-                **kwargs,
-            )
-        else:
-            backbone = load_backbone(config)
+        backbone = load_backbone(config)
+        self.intermediate_channel_sizes = backbone.channels

         # replace batch norm by frozen batch norm
         with torch.no_grad():
             replace_batch_norm(backbone)
-        self.model = backbone
-        self.intermediate_channel_sizes = (
-            self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
-        )

-
-
-
-
-
-
-
+        # We used to load with timm library directly instead of the AutoBackbone API
+        # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
+        is_timm_model = False
+        if hasattr(backbone, "_backbone"):
+            backbone = backbone._backbone
+            is_timm_model = True
+        self.model = backbone

+        backbone_model_type = config.backbone_config.model_type
         if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
-                if
+                if is_timm_model:
                     if "layer2" not in name and "layer3" not in name and "layer4" not in name:
                         parameter.requires_grad_(False)
                 else:
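For ResNet-style timm backbones, the constructor above freezes everything except `layer2`, `layer3` and `layer4`. A standalone sketch of that freezing rule and a quick way to check its effect; torchvision's `resnet50` is only a stand-in here, not what the library loads:

```python
from torchvision.models import resnet50  # stand-in backbone, assumption for illustration

backbone = resnet50(weights=None)
for name, parameter in backbone.named_parameters():
    # Same rule as in the hunk above: only layer2/layer3/layer4 stay trainable.
    if "layer2" not in name and "layer3" not in name and "layer4" not in name:
        parameter.requires_grad_(False)

trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
total = sum(p.numel() for p in backbone.parameters())
print(f"trainable parameters: {trainable:,} of {total:,}")
```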
@@ -310,7 +284,9 @@ class ConditionalDetrConvEncoder(nn.Module):

     def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
         # send pixel_values through the model to get list of feature maps
-        features = self.model(pixel_values)
+        features = self.model(pixel_values)
+        if isinstance(features, dict):
+            features = features.feature_maps

         out = []
         for feature_map in features:
@@ -320,66 +296,58 @@ class ConditionalDetrConvEncoder(nn.Module):
         return out


-# Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->ConditionalDetr
-class ConditionalDetrConvModel(nn.Module):
-    """
-    This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
-    """
-
-    def __init__(self, conv_encoder, position_embedding):
-        super().__init__()
-        self.conv_encoder = conv_encoder
-        self.position_embedding = position_embedding
-
-    def forward(self, pixel_values, pixel_mask):
-        # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
-        out = self.conv_encoder(pixel_values, pixel_mask)
-        pos = []
-        for feature_map, mask in out:
-            # position encoding
-            pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
-
-        return out, pos
-
-
 class ConditionalDetrSinePositionEmbedding(nn.Module):
     """
     This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
     need paper, generalized to work on images.
     """

-    def __init__(
+    def __init__(
+        self,
+        num_position_features: int = 64,
+        temperature: int = 10000,
+        normalize: bool = False,
+        scale: float | None = None,
+    ):
         super().__init__()
-        self.embedding_dim = embedding_dim
-        self.temperature = temperature
-        self.normalize = normalize
         if scale is not None and normalize is False:
             raise ValueError("normalize should be True if scale is passed")
-
-
-        self.
+        self.num_position_features = num_position_features
+        self.temperature = temperature
+        self.normalize = normalize
+        self.scale = 2 * math.pi if scale is None else scale

-
-
-
-
-
+    @compile_compatible_method_lru_cache(maxsize=1)
+    def forward(
+        self,
+        shape: torch.Size,
+        device: torch.device | str,
+        dtype: torch.dtype,
+        mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if mask is None:
+            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+        y_embed = mask.cumsum(1, dtype=dtype)
+        x_embed = mask.cumsum(2, dtype=dtype)
         if self.normalize:
-
-
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

-        dim_t = torch.arange(self.
-        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.
+        dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
+        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)

         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+        # expected by the encoder
+        pos = pos.flatten(2).permute(0, 2, 1)
         return pos


-# Copied from transformers.models.detr.modeling_detr.DetrLearnedPositionEmbedding with Detr->ConditionalDetr
 class ConditionalDetrLearnedPositionEmbedding(nn.Module):
     """
     This module learns positional embeddings up to a fixed maximum size.
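The rewritten sine embedding above now takes a `shape`/`device`/`dtype` triple and returns a flattened `(batch_size, height * width, 2 * num_position_features)` tensor. A self-contained shape check that mirrors the computation (standalone reimplementation, not the library module):

```python
import math
import torch

num_position_features, temperature = 64, 10000
batch_size, height, width = 2, 8, 6

# All-zero mask, as in the default path of the module above.
mask = torch.zeros((batch_size, height, width), dtype=torch.bool)
y_embed = mask.cumsum(1, dtype=torch.float32)
x_embed = mask.cumsum(2, dtype=torch.float32)

dim_t = torch.arange(num_position_features, dtype=torch.float32)
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_position_features)

pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)  # (B, 2*features, H, W)
pos = pos.flatten(2).permute(0, 2, 1)                        # (B, H*W, 2*features)
print(pos.shape)  # torch.Size([2, 48, 128])
```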
@@ -390,354 +358,385 @@ class ConditionalDetrLearnedPositionEmbedding(nn.Module):
         self.row_embeddings = nn.Embedding(50, embedding_dim)
         self.column_embeddings = nn.Embedding(50, embedding_dim)

-
-
-
-
+    @compile_compatible_method_lru_cache(maxsize=1)
+    def forward(
+        self,
+        shape: torch.Size,
+        device: torch.device | str,
+        dtype: torch.dtype,
+        mask: torch.Tensor | None = None,
+    ):
+        height, width = shape[-2:]
+        width_values = torch.arange(width, device=device)
+        height_values = torch.arange(height, device=device)
         x_emb = self.column_embeddings(width_values)
         y_emb = self.row_embeddings(height_values)
         pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
         pos = pos.permute(2, 0, 1)
         pos = pos.unsqueeze(0)
-        pos = pos.repeat(
+        pos = pos.repeat(shape[0], 1, 1, 1)
+        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+        # expected by the encoder
+        pos = pos.flatten(2).permute(0, 2, 1)
         return pos


-
-
-
-
-
-
-
-
-
-
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor | None,
+    scaling: float | None = None,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5

-
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

-
-
-    scale = 2 * math.pi
-    dim = d_model // 2
-    dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
-    dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
-    x_embed = pos_tensor[:, :, 0] * scale
-    y_embed = pos_tensor[:, :, 1] * scale
-    pos_x = x_embed[:, :, None] / dim_t
-    pos_y = y_embed[:, :, None] / dim_t
-    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
-    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
-    pos = torch.cat((pos_y, pos_x), dim=2)
-    return pos.to(pos_tensor.dtype)
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()

-
-    x = x.clamp(min=0, max=1)
-    x1 = x.clamp(min=eps)
-    x2 = (1 - x).clamp(min=eps)
-    return torch.log(x1 / x2)
+    return attn_output, attn_weights


-
-class DetrAttention(nn.Module):
+class ConditionalDetrSelfAttention(nn.Module):
     """
-    Multi-headed attention from 'Attention Is All You Need' paper.
+    Multi-headed self-attention from 'Attention Is All You Need' paper.

-
+    In CONDITIONAL_DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
     """

     def __init__(
         self,
-
-
+        config: ConditionalDetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
         dropout: float = 0.0,
         bias: bool = True,
     ):
         super().__init__()
-        self.
-        self.
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        if self.head_dim * num_heads != self.embed_dim:
-            raise ValueError(
-                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {num_heads})."
-            )
+        self.config = config
+        self.head_dim = hidden_size // num_attention_heads
         self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+        self.is_causal = False

-        self.k_proj = nn.Linear(
-        self.v_proj = nn.Linear(
-        self.q_proj = nn.Linear(
-        self.
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
-        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def with_pos_embed(self, tensor: torch.Tensor, object_queries: Tensor | None):
-        return tensor if object_queries is None else tensor + object_queries
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
-
-
-
-
-
-        """
-
-
-        is_cross_attention = key_value_states is not None
-        batch_size, target_len, embed_dim = hidden_states.size()
-
-        # add position embeddings to the hidden states before projecting to queries and keys
-        if object_queries is not None:
-            hidden_states_original = hidden_states
-            hidden_states = self.with_pos_embed(hidden_states, object_queries)
-
-        # add key-value position embeddings to the key value states
-        if spatial_position_embeddings is not None:
-            key_value_states_original = key_value_states
-            key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
-
-        # get query proj
-        query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        if is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
-            value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
-            value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
+        position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Position embeddings are added to both queries and keys (but not values).
+        """
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)

-
-        query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states

-
+        query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

-
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )

-
-
-
-
-
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights

-        if attention_mask is not None:
-            if attention_mask.size() != (batch_size, 1, target_len, source_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
-                    f" {attention_mask.size()}"
-                )
-            if attention_mask.dtype == torch.bool:
-                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
-                    attention_mask, -torch.inf
-                )
-            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
-            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
-            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
-        else:
-            attn_weights_reshaped = None

-
+class ConditionalDetrDecoderSelfAttention(nn.Module):
+    """
+    Multi-headed self-attention for Conditional DETR decoder layers.

-
+    This attention module handles separate content and position projections, which are then combined
+    before applying standard self-attention. Position embeddings are added to both queries and keys.
+    """

-
-
-
-
-
+    def __init__(
+        self,
+        config: ConditionalDetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = hidden_size
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+        self.is_causal = False
+
+        # Content and position projections
+        self.q_content_proj = nn.Linear(hidden_size, hidden_size)
+        self.q_pos_proj = nn.Linear(hidden_size, hidden_size)
+        self.k_content_proj = nn.Linear(hidden_size, hidden_size)
+        self.k_pos_proj = nn.Linear(hidden_size, hidden_size)
+        self.v_proj = nn.Linear(hidden_size, hidden_size)
+        self.o_proj = nn.Linear(hidden_size, hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        query_position_embeddings: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
+                Input hidden states from the decoder layer.
+            query_position_embeddings (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
+                Position embeddings for queries and keys. Required (unlike standard attention). Processed through
+                separate position projections (`q_pos_proj`, `k_pos_proj`) and added to content projections.
+            attention_mask (`torch.Tensor` of shape `(batch_size, 1, num_queries, num_queries)`, *optional*):
+                Attention mask to avoid attending to padding tokens.
+        """
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = (
+            (self.q_content_proj(hidden_states) + self.q_pos_proj(query_position_embeddings))
+            .view(hidden_shape)
+            .transpose(1, 2)
+        )
+        key_states = (
+            (self.k_content_proj(hidden_states) + self.k_pos_proj(query_position_embeddings))
+            .view(hidden_shape)
+            .transpose(1, 2)
+        )
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

-
-
-
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )

-        attn_output =
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )

-
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights


-class
+class ConditionalDetrDecoderCrossAttention(nn.Module):
     """
-
+    Multi-headed cross-attention for Conditional DETR decoder layers.

-
-
+    This attention module handles the special cross-attention logic in Conditional DETR:
+    - Separate content and position projections for queries and keys
+    - Concatenation of query sine embeddings with queries (doubling query dimension)
+    - Concatenation of key position embeddings with keys (doubling key dimension)
+    - Output dimension remains hidden_size despite doubled input dimensions
     """

     def __init__(
         self,
-
-
-
+        config: ConditionalDetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
         dropout: float = 0.0,
-        bias: bool = True,
     ):
         super().__init__()
-        self.
-        self.
-        self.
-        self.
-        self.
-
-
-
-
-
-        self.
-
-
-
-
-        self.
-
-
-
-
-
-
-    def _v_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
-        return tensor.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()
+        self.config = config
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.head_dim = hidden_size // num_attention_heads
+        self.attention_dropout = dropout
+        self.is_causal = False
+
+        # Content and position projections
+        self.q_content_proj = nn.Linear(hidden_size, hidden_size)
+        self.q_pos_proj = nn.Linear(hidden_size, hidden_size)
+        self.k_content_proj = nn.Linear(hidden_size, hidden_size)
+        self.k_pos_proj = nn.Linear(hidden_size, hidden_size)
+        self.v_proj = nn.Linear(hidden_size, hidden_size)
+        self.q_pos_sine_proj = nn.Linear(hidden_size, hidden_size)
+
+        # Output projection: input is hidden_size * 2 (from concatenated q/k), output is hidden_size
+        self.o_proj = nn.Linear(hidden_size, hidden_size)
+
+        # Compute scaling for expanded head_dim (q and k have doubled dimensions after concatenation)
+        # This matches the original Conditional DETR implementation where embed_dim * 2 is used
+        expanded_head_dim = (hidden_size * 2) // num_attention_heads
+        self.scaling = expanded_head_dim**-0.5

     def forward(
         self,
         hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        query_sine_embed: torch.Tensor,
+        encoder_position_embeddings: torch.Tensor,
+        query_position_embeddings: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        else:
-            attn_weights_reshaped = None
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
+                Decoder hidden states (queries).
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, encoder_seq_len, hidden_size)`):
+                Encoder output hidden states (keys and values).
+            query_sine_embed (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
+                Sine position embeddings for queries. **Concatenated** (not added) with query content,
+                doubling the query dimension.
+            encoder_position_embeddings (`torch.Tensor` of shape `(batch_size, encoder_seq_len, hidden_size)`):
+                Position embeddings for keys. **Concatenated** (not added) with key content, doubling the key dimension.
+            query_position_embeddings (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+                Additional position embeddings. When provided (first layer only), **added** to query content
+                before concatenation with `query_sine_embed`. Also causes `encoder_position_embeddings` to be
+                added to key content before concatenation.
+            attention_mask (`torch.Tensor` of shape `(batch_size, 1, num_queries, encoder_seq_len)`, *optional*):
+                Attention mask to avoid attending to padding tokens.
+        """
+        query_input_shape = hidden_states.shape[:-1]
+        kv_input_shape = encoder_hidden_states.shape[:-1]
+        query_hidden_shape = (*query_input_shape, self.num_attention_heads, self.head_dim)
+        kv_hidden_shape = (*kv_input_shape, self.num_attention_heads, self.head_dim)
+
+        # Apply content and position projections
+        query_input = self.q_content_proj(hidden_states)
+        key_input = self.k_content_proj(encoder_hidden_states)
+        value_states = self.v_proj(encoder_hidden_states)
+        key_pos = self.k_pos_proj(encoder_position_embeddings)
+
+        # Combine content and position embeddings
+        if query_position_embeddings is not None:
+            query_input = query_input + self.q_pos_proj(query_position_embeddings)
+            key_input = key_input + key_pos
+
+        # Reshape and concatenate position embeddings (doubling head_dim)
+        query_input = query_input.view(query_hidden_shape)
+        key_input = key_input.view(kv_hidden_shape)
+        query_sine_embed = self.q_pos_sine_proj(query_sine_embed).view(query_hidden_shape)
+        key_pos = key_pos.view(kv_hidden_shape)
+
+        query_states = torch.cat([query_input, query_sine_embed], dim=-1).view(*query_input_shape, -1)
+        key_states = torch.cat([key_input, key_pos], dim=-1).view(*kv_input_shape, -1)
+
+        # Reshape for attention computation
+        expanded_head_dim = query_states.shape[-1] // self.num_attention_heads
+        query_states = query_states.view(*query_input_shape, self.num_attention_heads, expanded_head_dim).transpose(
+            1, 2
+        )
+        key_states = key_states.view(*kv_input_shape, self.num_attention_heads, expanded_head_dim).transpose(1, 2)
+        value_states = value_states.view(kv_hidden_shape).transpose(1, 2)

-
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )

-        attn_output =
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )

-
-
-
-                    f" {attn_output.size()}"
-                )
+        attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights

-        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.v_head_dim)
-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(batch_size, target_len, self.out_dim)

-
+class ConditionalDetrMLP(nn.Module):
+    def __init__(self, config: ConditionalDetrConfig, hidden_size: int, intermediate_size: int):
+        super().__init__()
+        self.fc1 = nn.Linear(hidden_size, intermediate_size)
+        self.fc2 = nn.Linear(intermediate_size, hidden_size)
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.dropout = config.dropout

-
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        return hidden_states


-
-class ConditionalDetrEncoderLayer(nn.Module):
+class ConditionalDetrEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: ConditionalDetrConfig):
         super().__init__()
-        self.
-        self.self_attn =
-
-
+        self.hidden_size = config.d_model
+        self.self_attn = ConditionalDetrSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
         )
-        self.self_attn_layer_norm = nn.LayerNorm(self.
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
         self.dropout = config.dropout
-        self.
-        self.
-        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
-        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.mlp = ConditionalDetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-
-
-    ):
+        spatial_position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len,
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-
-
-
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
+            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
+                Spatial position embeddings (2D positional encodings of image locations), to be added to both
+                the queries and keys in self-attention (but not to values).
         """
         residual = hidden_states
-        hidden_states,
+        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
-
-
+            position_embeddings=spatial_position_embeddings,
+            **kwargs,
         )

         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
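The cross-attention rewrite above is the distinctive Conditional DETR piece: the projected sine embedding is concatenated with the per-head content rather than added, so queries and keys run at twice the head dimension while values keep the original size, and the softmax scaling uses the expanded head dimension. A toy shape check under those assumptions (standalone, not the library module):

```python
import torch

batch, num_queries, num_heads, hidden_size = 1, 300, 8, 256
head_dim = hidden_size // num_heads                       # 32

q_content = torch.randn(batch, num_queries, num_heads, head_dim)
q_sine = torch.randn(batch, num_queries, num_heads, head_dim)

# Concatenate per head: query head_dim doubles, values would stay at head_dim.
q = torch.cat([q_content, q_sine], dim=-1)                # (1, 300, 8, 64)
expanded_head_dim = q.shape[-1]                           # 2 * head_dim
scaling = ((hidden_size * 2) // num_heads) ** -0.5        # same rule as the module above

print(q.shape, expanded_head_dim, round(scaling, 4))      # torch.Size([1, 300, 8, 64]) 64 0.125
```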
@@ -745,12 +744,7 @@ class ConditionalDetrEncoderLayer(nn.Module):
         hidden_states = self.self_attn_layer_norm(hidden_states)

         residual = hidden_states
-        hidden_states = self.
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
+        hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)

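Taken together, the encoder-layer hunks above collapse the inlined feed-forward code into a single `ConditionalDetrMLP` call inside the usual post-norm residual pattern. A compact stand-in showing that control flow (generic modules, not the library classes):

```python
import torch
from torch import nn

hidden_size = 256
layer_norm1 = nn.LayerNorm(hidden_size)
layer_norm2 = nn.LayerNorm(hidden_size)
mlp = nn.Sequential(nn.Linear(hidden_size, 2048), nn.ReLU(), nn.Linear(2048, hidden_size))

def encoder_block(hidden_states, attn_module, dropout_p=0.1):
    residual = hidden_states
    hidden_states = attn_module(hidden_states)                       # stand-in for self_attn(...)
    hidden_states = nn.functional.dropout(hidden_states, p=dropout_p)
    hidden_states = layer_norm1(residual + hidden_states)
    residual = hidden_states
    hidden_states = mlp(hidden_states)                               # stand-in for ConditionalDetrMLP
    return layer_norm2(residual + hidden_states)

out = encoder_block(torch.randn(1, 10, hidden_size), nn.Identity())
print(out.shape)  # torch.Size([1, 10, 256])
```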
@@ -759,80 +753,55 @@ class ConditionalDetrEncoderLayer(nn.Module):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

-
-
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs
+        return hidden_states


 class ConditionalDetrDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: ConditionalDetrConfig):
         super().__init__()
-        self.
-
-
-
-
-        self.sa_qpos_proj = nn.Linear(d_model, d_model)
-        self.sa_kcontent_proj = nn.Linear(d_model, d_model)
-        self.sa_kpos_proj = nn.Linear(d_model, d_model)
-        self.sa_v_proj = nn.Linear(d_model, d_model)
-
-        self.self_attn = ConditionalDetrAttention(
-            embed_dim=self.embed_dim,
-            out_dim=self.embed_dim,
-            num_heads=config.decoder_attention_heads,
+        self.hidden_size = config.d_model
+        self.self_attn = ConditionalDetrDecoderSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.decoder_attention_heads,
             dropout=config.attention_dropout,
         )
         self.dropout = config.dropout
-        self.activation_fn = ACT2FN[config.activation_function]
-        self.activation_dropout = config.activation_dropout

-        self.self_attn_layer_norm = nn.LayerNorm(self.
-
-
-
-
-
-        self.ca_kpos_proj = nn.Linear(d_model, d_model)
-        self.ca_v_proj = nn.Linear(d_model, d_model)
-        self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
-
-        self.encoder_attn = ConditionalDetrAttention(
-            self.embed_dim * 2, self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+        self.encoder_attn = ConditionalDetrDecoderCrossAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
         )
-        self.encoder_attn_layer_norm = nn.LayerNorm(self.
-        self.
-        self.
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.nhead = config.decoder_attention_heads
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+        self.mlp = ConditionalDetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
-
+        spatial_position_embeddings: torch.Tensor | None = None,
         query_position_embeddings: torch.Tensor | None = None,
         query_sine_embed: torch.Tensor | None = None,
         encoder_hidden_states: torch.Tensor | None = None,
         encoder_attention_mask: torch.Tensor | None = None,
-        output_attentions: bool | None = False,
         is_first: bool | None = False,
-
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-
-
-                in the cross-attention layer.
+            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
             query_position_embeddings (`torch.FloatTensor`, *optional*):
                 object_queries that are added to the queries and keys
-
+                in the self-attention layer.
             encoder_hidden_states (`torch.FloatTensor`):
                 cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
             encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
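The decoder layer now simply composes the two attention modules and the shared MLP defined earlier. A minimal stand-in showing the same call order, self-attention, optional cross-attention, then the MLP, each wrapped in dropout plus residual plus LayerNorm; the attention modules here are placeholders, not the library classes:

```python
import torch
from torch import nn

hidden_size = 256
self_attn_norm = nn.LayerNorm(hidden_size)
cross_attn_norm = nn.LayerNorm(hidden_size)
final_norm = nn.LayerNorm(hidden_size)
mlp = nn.Sequential(nn.Linear(hidden_size, 2048), nn.ReLU(), nn.Linear(2048, hidden_size))

def decoder_block(queries, memory, self_attn, cross_attn, dropout_p=0.1):
    residual = queries
    queries = self_attn(queries)                   # stand-in for ConditionalDetrDecoderSelfAttention
    queries = self_attn_norm(residual + nn.functional.dropout(queries, p=dropout_p))
    if memory is not None:                         # cross-attention only runs when encoder output is given
        residual = queries
        queries = cross_attn(queries)              # stand-in for ConditionalDetrDecoderCrossAttention
        queries = cross_attn_norm(residual + nn.functional.dropout(queries, p=dropout_p))
    residual = queries
    return final_norm(residual + mlp(queries))

out = decoder_block(
    torch.randn(1, 300, hidden_size), torch.randn(1, 950, hidden_size), nn.Identity(), nn.Identity()
)
print(out.shape)  # torch.Size([1, 300, 256])
```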
@@ -844,108 +813,49 @@ class ConditionalDetrDecoderLayer(GradientCheckpointingLayer):
         """
         residual = hidden_states

-
-
-
-        q_content = self.sa_qcontent_proj(
-            hidden_states
-        )  # target is the input of the first decoder layer. zero by default.
-        q_pos = self.sa_qpos_proj(query_position_embeddings)
-        k_content = self.sa_kcontent_proj(hidden_states)
-        k_pos = self.sa_kpos_proj(query_position_embeddings)
-        v = self.sa_v_proj(hidden_states)
-
-        _, num_queries, n_model = q_content.shape
-
-        q = q_content + q_pos
-        k = k_content + k_pos
-        hidden_states, self_attn_weights = self.self_attn(
-            hidden_states=q,
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            query_position_embeddings=query_position_embeddings,
             attention_mask=attention_mask,
-
-            value_states=v,
-            output_attentions=output_attentions,
+            **kwargs,
         )
-        # ============ End of Self-Attention =============

         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)

-        # ========== Begin of Cross-Attention =============
-        # Apply projections here
-        # shape: num_queries x batch_size x 256
-        q_content = self.ca_qcontent_proj(hidden_states)
-        k_content = self.ca_kcontent_proj(encoder_hidden_states)
-        v = self.ca_v_proj(encoder_hidden_states)
-
-        batch_size, num_queries, n_model = q_content.shape
-        _, source_len, _ = k_content.shape
-
-        k_pos = self.ca_kpos_proj(object_queries)
-
-        # For the first decoder layer, we concatenate the positional embedding predicted from
-        # the object query (the positional embedding) into the original query (key) in DETR.
-        if is_first:
-            q_pos = self.ca_qpos_proj(query_position_embeddings)
-            q = q_content + q_pos
-            k = k_content + k_pos
-        else:
-            q = q_content
-            k = k_content
-
-        q = q.view(batch_size, num_queries, self.nhead, n_model // self.nhead)
-        query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
-        query_sine_embed = query_sine_embed.view(batch_size, num_queries, self.nhead, n_model // self.nhead)
-        q = torch.cat([q, query_sine_embed], dim=3).view(batch_size, num_queries, n_model * 2)
-        k = k.view(batch_size, source_len, self.nhead, n_model // self.nhead)
-        k_pos = k_pos.view(batch_size, source_len, self.nhead, n_model // self.nhead)
-        k = torch.cat([k, k_pos], dim=3).view(batch_size, source_len, n_model * 2)
-
-        # Cross-Attention Block
-        cross_attn_weights = None
         if encoder_hidden_states is not None:
             residual = hidden_states

-            hidden_states,
-                hidden_states=
+            hidden_states, _ = self.encoder_attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
-
-
-
+                query_sine_embed=query_sine_embed,
+                encoder_position_embeddings=spatial_position_embeddings,
+                # Only pass query_position_embeddings for the first layer
+                query_position_embeddings=query_position_embeddings if is_first else None,
+                **kwargs,
             )

             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
             hidden_states = residual + hidden_states
             hidden_states = self.encoder_attn_layer_norm(hidden_states)

-            # ============ End of Cross-Attention =============
-
         # Fully Connected
         residual = hidden_states
-        hidden_states = self.
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)

-
+        return hidden_states

-        if output_attentions:
-            outputs += (self_attn_weights, cross_attn_weights)

-
-
-
-# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with DetrMLPPredictionHead->MLP
-class MLP(nn.Module):
+class ConditionalDetrMLPPredictionHead(nn.Module):
     """
     Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
     height and width of a bounding box w.r.t. an image.

-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-
     """

     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
@@ -960,29 +870,202 @@ class MLP(nn.Module):
         return x


+class ConditionalDetrConvBlock(nn.Module):
+    """Basic conv block: Conv3x3 -> GroupNorm -> Activation."""
+
+    def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+        self.norm = nn.GroupNorm(min(8, out_channels), out_channels)
+        self.activation = ACT2FN[activation]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.activation(self.norm(self.conv(x)))
+
+
+class ConditionalDetrFPNFusionStage(nn.Module):
+    """Single FPN fusion stage combining low-resolution features with high-resolution FPN features."""
+
+    def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"):
+        super().__init__()
+        self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
+        self.refine = ConditionalDetrConvBlock(current_channels, output_channels, activation)
+
+    def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in)
+            fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out)
+
+        Returns:
+            Fused and refined features, shape (B*Q, output_channels, H_out, W_out)
+        """
+        fpn_features = self.fpn_adapter(fpn_features)
+        features = nn.functional.interpolate(features, size=fpn_features.shape[-2:], mode="nearest")
+        return self.refine(fpn_features + features)
+
+
+class ConditionalDetrMaskHeadSmallConv(nn.Module):
+    """
+    Segmentation mask head that generates per-query masks using FPN-based progressive upsampling.
+
+    Combines attention maps (spatial localization) with encoder features (semantics) and progressively
+    upsamples through multiple scales, fusing with FPN features for high-resolution detail.
+    """
+
+    def __init__(
+        self,
+        input_channels: int,
+        fpn_channels: list[int],
+        hidden_size: int,
+        activation_function: str = "relu",
+    ):
+        super().__init__()
+        if input_channels % 8 != 0:
+            raise ValueError(f"input_channels must be divisible by 8, got {input_channels}")
+
+        self.conv1 = ConditionalDetrConvBlock(input_channels, input_channels, activation_function)
+        self.conv2 = ConditionalDetrConvBlock(input_channels, hidden_size // 2, activation_function)
+
+        # Progressive channel reduction: /2 -> /4 -> /8 -> /16
+        self.fpn_stages = nn.ModuleList(
+            [
+                ConditionalDetrFPNFusionStage(
+                    fpn_channels[0], hidden_size // 2, hidden_size // 4, activation_function
+                ),
+                ConditionalDetrFPNFusionStage(
+                    fpn_channels[1], hidden_size // 4, hidden_size // 8, activation_function
+                ),
+                ConditionalDetrFPNFusionStage(
+                    fpn_channels[2], hidden_size // 8, hidden_size // 16, activation_function
+                ),
+            ]
+        )
+
+        self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1)
+
+    def forward(
+        self,
+        features: torch.Tensor,
+        attention_masks: torch.Tensor,
+        fpn_features: list[torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Args:
+            features: Encoder output features, shape (batch_size, hidden_size, H, W)
+            attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W)
+            fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, W)
+
+        Returns:
+            Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W)
+        """
+        num_queries = attention_masks.shape[1]
+
+        # Expand to (batch_size * num_queries) dimension
+        features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)
+        attention_masks = attention_masks.flatten(0, 1)
+        fpn_features = [
+            fpn_feat.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) for fpn_feat in fpn_features
+        ]
+
+        hidden_states = torch.cat([features, attention_masks], dim=1)
+        hidden_states = self.conv1(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        for fpn_stage, fpn_feat in zip(self.fpn_stages, fpn_features):
+            hidden_states = fpn_stage(hidden_states, fpn_feat)
+
+        return self.output_conv(hidden_states)
+
+
+class ConditionalDetrMHAttentionMap(nn.Module):
+    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+    def forward(
+        self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None
+    ):
+        query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim)
+        key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:])
+
+        query_states = self.q_proj(query_states).view(query_hidden_shape)
+        key_states = nn.functional.conv2d(
+            key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias
+        ).view(key_hidden_shape)
+
+        batch_size, num_queries, num_heads, head_dim = query_states.shape
+        _, _, _, height, width = key_states.shape
+        query_shape = (batch_size * num_heads, num_queries, head_dim)
+        key_shape = (batch_size * num_heads, height * width, head_dim)
+        attn_weights_shape = (batch_size, num_heads, num_queries, height, width)
+
+        query = query_states.transpose(1, 2).contiguous().view(query_shape)
+        key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape)
+
+        attn_weights = (
+            (torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2)
+        )
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size())
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+        return attn_weights
+
+
 @auto_docstring
-# Copied from transformers.models.detr.modeling_detr.DetrPreTrainedModel with Detr->ConditionalDetr
 class ConditionalDetrPreTrainedModel(PreTrainedModel):
     config: ConditionalDetrConfig
     base_model_prefix = "model"
     main_input_name = "pixel_values"
     input_modalities = ("image",)
     _no_split_modules = [r"ConditionalDetrConvEncoder", r"ConditionalDetrEncoderLayer", r"ConditionalDetrDecoderLayer"]
+    supports_gradient_checkpointing = True
+    _supports_sdpa = True
+    _supports_flash_attn = True
+    _supports_attention_backend = True
+    _supports_flex_attn = True  # Uses create_bidirectional_masks for attention masking
+    _keys_to_ignore_on_load_unexpected = [
+        r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
+    ]

     @torch.no_grad()
     def _init_weights(self, module):
         std = self.config.init_std
         xavier_std = self.config.init_xavier_std

-        if isinstance(module, ConditionalDetrMHAttentionMap):
-            init.zeros_(module.k_linear.bias)
-            init.zeros_(module.q_linear.bias)
-            init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
-            init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
+        if isinstance(module, ConditionalDetrMaskHeadSmallConv):
+            # ConditionalDetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers
+            for m in module.modules():
+                if isinstance(m, nn.Conv2d):
+                    init.kaiming_uniform_(m.weight, a=1)
+                    if m.bias is not None:
+                        init.constant_(m.bias, 0)
+        elif isinstance(module, ConditionalDetrMHAttentionMap):
+            init.zeros_(module.k_proj.bias)
+            init.zeros_(module.q_proj.bias)
+            init.xavier_uniform_(module.k_proj.weight, gain=xavier_std)
+            init.xavier_uniform_(module.q_proj.weight, gain=xavier_std)
         elif isinstance(module, ConditionalDetrLearnedPositionEmbedding):
             init.uniform_(module.row_embeddings.weight)
             init.uniform_(module.column_embeddings.weight)
-        if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
             init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
                 init.zeros_(module.bias)
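Side note on the new ConditionalDetrFPNFusionStage above: each stage is a 1x1 "adapter" convolution on the higher-resolution FPN feature, a nearest-neighbour upsample of the coarser feature to that resolution, an element-wise sum, and a 3x3 conv + GroupNorm + activation refinement. A minimal, self-contained sketch of that fusion pattern in plain PyTorch, with made-up channel sizes and tensors (illustrative only, not the package's class):

import torch
from torch import nn

# Hypothetical sizes: coarse features (B*Q, 128, 16, 16) fused with FPN features (B*Q, 512, 32, 32).
coarse = torch.randn(2, 128, 16, 16)
fpn = torch.randn(2, 512, 32, 32)

adapter = nn.Conv2d(512, 128, kernel_size=1)        # project FPN channels down to the current width
refine = nn.Sequential(                              # 3x3 conv -> GroupNorm -> ReLU, as in the conv block above
    nn.Conv2d(128, 64, kernel_size=3, padding=1),
    nn.GroupNorm(min(8, 64), 64),
    nn.ReLU(),
)

fpn_proj = adapter(fpn)
upsampled = nn.functional.interpolate(coarse, size=fpn_proj.shape[-2:], mode="nearest")
fused = refine(fpn_proj + upsampled)
print(fused.shape)  # torch.Size([2, 64, 32, 32])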
@@ -996,50 +1079,38 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
             init.zeros_(module.bias)


-# Copied from transformers.models.detr.modeling_detr.DetrEncoder with Detr->ConditionalDetr,DETR->ConditionalDETR
 class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
     """
-    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
-    [`ConditionalDetrEncoderLayer`].
-
-    The encoder updates the flattened feature map through multiple self-attention layers.
-
-    Small tweak for ConditionalDETR:
-
-    - object_queries are added to the forward pass.
+    Transformer encoder that processes a flattened feature map from a vision backbone, composed of a stack of
+    [`ConditionalDetrEncoderLayer`] modules.

     Args:
-        config: ConditionalDetrConfig
+        config (`ConditionalDetrConfig`): Model configuration object.
     """

+    _can_record_outputs = {"hidden_states": ConditionalDetrEncoderLayer, "attentions": ConditionalDetrSelfAttention}
+
     def __init__(self, config: ConditionalDetrConfig):
         super().__init__(config)

         self.dropout = config.dropout
-        self.layerdrop = config.encoder_layerdrop
-
         self.layers = nn.ModuleList([ConditionalDetrEncoderLayer(config) for _ in range(config.encoder_layers)])

-        # in the original ConditionalDETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default
-
         # Initialize weights and apply final processing
         self.post_init()

+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         attention_mask=None,
-        object_queries=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        **kwargs,
-    ):
+        spatial_position_embeddings=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
-
             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:

@@ -1047,69 +1118,44 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
                 - 0 for pixel features that are padding (i.e. **masked**).

                 [What are attention masks?](../glossary#attention-mask)
-
-            object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-                Object queries that are added to the queries in each self-attention layer.
-
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         hidden_states = inputs_embeds
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

         # expand attention_mask
         if attention_mask is not None:
             # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
-
-        encoder_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
-        for i, encoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                encoder_states = encoder_states + (hidden_states,)
-            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
-            to_drop = False
-            if self.training:
-                dropout_probability = torch.rand([])
-                if dropout_probability < self.layerdrop:  # skip the layer
-                    to_drop = True
+            attention_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+            )

-            if to_drop:
-                layer_outputs = (None, None)
-            else:
-                # we add object_queries as extra input to the encoder_layer
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    object_queries=object_queries,
-                    output_attentions=output_attentions,
-                )
-
-                hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
-
-        if output_hidden_states:
-            encoder_states = encoder_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
-        )
+        for encoder_layer in self.layers:
+            # we add spatial_position_embeddings as extra input to the encoder_layer
+            hidden_states = encoder_layer(
+                hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs
+            )
+
+        return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+# function to generate sine positional embedding for 2d coordinates
+def gen_sine_position_embeddings(pos_tensor, d_model):
+    scale = 2 * math.pi
+    dim = d_model // 2
+    dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
+    dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
+    x_embed = pos_tensor[:, :, 0] * scale
+    y_embed = pos_tensor[:, :, 1] * scale
+    pos_x = x_embed[:, :, None] / dim_t
+    pos_y = y_embed[:, :, None] / dim_t
+    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+    pos = torch.cat((pos_y, pos_x), dim=2)
+    return pos.to(pos_tensor.dtype)


 class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
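Side note on gen_sine_position_embeddings above: for a (batch, num_queries, 2) tensor of normalized (x, y) reference points it builds d_model // 2 interleaved sine/cosine features per coordinate and concatenates the y and x halves, giving a (batch, num_queries, d_model) embedding. A small shape check that re-implements the same steps (illustrative only):

import math
import torch

def sine_embed_2d(pos_tensor, d_model):
    # same recipe as gen_sine_position_embeddings in the hunk above
    scale = 2 * math.pi
    dim = d_model // 2
    dim_t = torch.arange(dim, dtype=torch.float32)
    dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
    return torch.cat((pos_y, pos_x), dim=2)

ref_points = torch.rand(1, 300, 2)             # normalized (x, y) centers for 300 queries
print(sine_embed_2d(ref_points, 256).shape)    # torch.Size([1, 300, 256])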
@@ -1127,39 +1173,44 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
         config: ConditionalDetrConfig
     """

+    _can_record_outputs = {
+        "hidden_states": ConditionalDetrDecoderLayer,
+        "attentions": OutputRecorder(ConditionalDetrDecoderSelfAttention, layer_name="self_attn", index=1),
+        "cross_attentions": OutputRecorder(ConditionalDetrDecoderCrossAttention, layer_name="encoder_attn", index=1),
+    }
+
     def __init__(self, config: ConditionalDetrConfig):
         super().__init__(config)
+        self.hidden_size = config.d_model
+
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop

         self.layers = nn.ModuleList([ConditionalDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
         # in Conditional DETR, the decoder uses layernorm after the last decoder layer output
         self.layernorm = nn.LayerNorm(config.d_model)
-        d_model = config.d_model
-        self.gradient_checkpointing = False

         # query_scale is the FFN applied on f to generate transformation T
-        self.query_scale = MLP(d_model, d_model, d_model, 2)
-        self.ref_point_head = MLP(d_model, d_model, 2, 2)
+        self.query_scale = ConditionalDetrMLPPredictionHead(self.hidden_size, self.hidden_size, self.hidden_size, 2)
+        self.ref_point_head = ConditionalDetrMLPPredictionHead(self.hidden_size, self.hidden_size, 2, 2)
         for layer_id in range(config.decoder_layers - 1):
-            self.layers[layer_id + 1].ca_qpos_proj = None
+            # Set q_pos_proj to None for layers after the first (only first layer uses query position embeddings)
+            self.layers[layer_id + 1].encoder_attn.q_pos_proj = None

         # Initialize weights and apply final processing
         self.post_init()

+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
-        object_queries=None,
-        query_position_embeddings=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        **kwargs,
-    ):
+        spatial_position_embeddings=None,
+        object_queries_position_embeddings=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> ConditionalDetrDecoderOutput:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -1182,46 +1233,28 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
                 - 1 for pixels that are real (i.e. **not masked**),
                 - 0 for pixels that are padding (i.e. **masked**).

-            object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-                Object queries that are added to the queries and keys in each cross-attention layer.
-            query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`
+            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Spatial position embeddings that are added to the queries and keys in each cross-attention layer.
+            object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
             , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
-            input_shape = inputs_embeds.size()[:-1]

         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-            encoder_attention_mask = _prepare_4d_attention_mask(
-                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            encoder_attention_mask = create_bidirectional_mask(
+                self.config,
+                inputs_embeds,
+                encoder_attention_mask,
             )

         # optional intermediate hidden states
         intermediate = () if self.config.auxiliary_loss else None

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
-
         reference_points_before_sigmoid = self.ref_point_head(
-            query_position_embeddings
+            object_queries_position_embeddings
         )  # [num_queries, batch_size, 2]
         reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1)
         obj_center = reference_points[..., :2].transpose(0, 1)
@@ -1229,9 +1262,6 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
         query_sine_embed_before_transformation = gen_sine_position_embeddings(obj_center, self.config.d_model)

         for idx, decoder_layer in enumerate(self.layers):
-            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
             if self.training:
                 dropout_probability = torch.rand([])
                 if dropout_probability < self.layerdrop:
@@ -1243,59 +1273,31 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
             # apply transformation
             query_sine_embed = query_sine_embed_before_transformation * pos_transformation

-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
-                None,
-                object_queries,
-                query_position_embeddings,
+                None,
+                spatial_position_embeddings,
+                object_queries_position_embeddings,
                 query_sine_embed,
                 encoder_hidden_states,  # as a positional argument for gradient checkpointing
                 encoder_attention_mask=encoder_attention_mask,
-                output_attentions=output_attentions,
                 is_first=(idx == 0),
+                **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
             if self.config.auxiliary_loss:
                 hidden_states = self.layernorm(hidden_states)
                 intermediate += (hidden_states,)

-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-                if encoder_hidden_states is not None:
-                    all_cross_attentions += (layer_outputs[2],)
-
         # finally, apply layernorm
         hidden_states = self.layernorm(hidden_states)

-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         # stack intermediate decoder activations
         if self.config.auxiliary_loss:
             intermediate = torch.stack(intermediate)

-        if not return_dict:
-            return tuple(
-                v
-                for v in [
-                    hidden_states,
-                    all_hidden_states,
-                    all_self_attns,
-                    all_cross_attentions,
-                    intermediate,
-                    reference_points,
-                ]
-                if v is not None
-            )
         return ConditionalDetrDecoderOutput(
             last_hidden_state=hidden_states,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-            cross_attentions=all_cross_attentions,
             intermediate_hidden_states=intermediate,
             reference_points=reference_points,
         )
@@ -1303,23 +1305,24 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):

 @auto_docstring(
     custom_intro="""
-    The bare Conditional DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw
-    hidden-states without any specific head on top.
+    The bare CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
+    any specific head on top.
     """
 )
 class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
     def __init__(self, config: ConditionalDetrConfig):
         super().__init__(config)

-        # Create backbone + positional encoding
-        backbone = ConditionalDetrConvEncoder(config)
-        object_queries = build_position_encoding(config)
-        self.backbone = ConditionalDetrConvModel(backbone, object_queries)
-
-        # Create projection layer
-        self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
+        self.backbone = ConditionalDetrConvEncoder(config)

+        if config.position_embedding_type == "sine":
+            self.position_embedding = ConditionalDetrSinePositionEmbedding(config.d_model // 2, normalize=True)
+        elif config.position_embedding_type == "learned":
+            self.position_embedding = ConditionalDetrLearnedPositionEmbedding(config.d_model // 2)
+        else:
+            raise ValueError(f"Not supported {config.position_embedding_type}")
         self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
+        self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)

         self.encoder = ConditionalDetrEncoder(config)
         self.decoder = ConditionalDetrDecoder(config)
@@ -1328,14 +1331,15 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
         self.post_init()

     def freeze_backbone(self):
-        for name, param in self.backbone.conv_encoder.model.named_parameters():
+        for _, param in self.backbone.model.named_parameters():
             param.requires_grad_(False)

     def unfreeze_backbone(self):
-        for name, param in self.backbone.conv_encoder.model.named_parameters():
+        for _, param in self.backbone.model.named_parameters():
             param.requires_grad_(True)

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
@@ -1344,11 +1348,8 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
         encoder_outputs: torch.FloatTensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         decoder_inputs_embeds: torch.FloatTensor | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
-    ) -> tuple[torch.FloatTensor] | ConditionalDetrModelOutput:
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> ConditionalDetrModelOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
             Not used by default. Can be used to mask object queries.
@@ -1384,12 +1385,6 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
         >>> list(last_hidden_states.shape)
         [1, 300, 256]
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         batch_size, num_channels, height, width = pixel_values.shape
         device = pixel_values.device
@@ -1399,7 +1394,7 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
         # First, sent pixel_values + pixel_mask through Backbone to obtain the features
         # pixel_values should be of shape (batch_size, num_channels, height, width)
         # pixel_mask should be of shape (batch_size, height, width)
-        features, object_queries_list = self.backbone(pixel_values, pixel_mask)
+        features = self.backbone(pixel_values, pixel_mask)

         # get final feature map and downsampled mask
         feature_map, mask = features[-1]
@@ -1410,53 +1405,52 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
         # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
         projected_feature_map = self.input_projection(feature_map)

-        # Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
+        # Generate position embeddings
+        spatial_position_embeddings = self.position_embedding(
+            shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
+        )
+
+        # Third, flatten the feature map of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
         # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
         flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
-        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)

         flattened_mask = mask.flatten(1)

-        # Fourth, sent flattened_features + flattened_mask + object_queries through encoder
+        # Fourth, sent flattened_features + flattened_mask + spatial_position_embeddings through encoder
         # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
         # flattened_mask is a Tensor of shape (batch_size, height*width)
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 inputs_embeds=flattened_features,
                 attention_mask=flattened_mask,
-                object_queries=object_queries,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
+                spatial_position_embeddings=spatial_position_embeddings,
+                **kwargs,
             )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+        elif not isinstance(encoder_outputs, BaseModelOutput):
             encoder_outputs = BaseModelOutput(
                 last_hidden_state=encoder_outputs[0],
                 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
             )

-        # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
-        query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
-        queries = torch.zeros_like(query_position_embeddings)
+        # Fifth, sent query embeddings through the decoder (which is conditioned on the encoder output)
+        object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(
+            batch_size, 1, 1
+        )
+        queries = torch.zeros_like(object_queries_position_embeddings)

         # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
         decoder_outputs = self.decoder(
             inputs_embeds=queries,
             attention_mask=None,
-            object_queries=object_queries,
-            query_position_embeddings=query_position_embeddings,
-            encoder_hidden_states=encoder_outputs[0],
+            spatial_position_embeddings=spatial_position_embeddings,
+            object_queries_position_embeddings=object_queries_position_embeddings,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
             encoder_attention_mask=flattened_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

-        if not return_dict:
-            return decoder_outputs + encoder_outputs
-
         return ConditionalDetrModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
             decoder_hidden_states=decoder_outputs.hidden_states,
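Side note on the "flatten the feature map of shape NxCxHxW to NxCxHW, and permute it to NxHWxC" step used above: it is a plain reshape. A minimal sketch with hypothetical sizes (illustrative only):

import torch

feature_map = torch.randn(2, 256, 25, 34)             # (batch_size, d_model, H, W) after the 1x1 projection
flattened = feature_map.flatten(2).permute(0, 2, 1)   # (batch_size, H*W, d_model) = (2, 850, 256)
print(flattened.shape)

# The segmentation head later undoes this to recover the spatial layout:
memory = flattened.permute(0, 2, 1).view(2, 256, 25, 34)
print(torch.equal(memory, feature_map))  # True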
@@ -1470,45 +1464,26 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
         )


-# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->ConditionalDetr
-class ConditionalDetrMLPPredictionHead(nn.Module):
-    """
-    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
-    height and width of a bounding box w.r.t. an image.
-
-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-
-    """
-
-    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
-        super().__init__()
-        self.num_layers = num_layers
-        h = [hidden_dim] * (num_layers - 1)
-        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
-
-    def forward(self, x):
-        for i, layer in enumerate(self.layers):
-            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
-        return x
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0, max=1)
+    x1 = x.clamp(min=eps)
+    x2 = (1 - x).clamp(min=eps)
+    return torch.log(x1 / x2)


 @auto_docstring(
     custom_intro="""
-
-
+    CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
+    such as COCO detection.
     """
 )
 class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
     def __init__(self, config: ConditionalDetrConfig):
         super().__init__(config)

-        #
+        # CONDITIONAL_DETR encoder-decoder model
         self.model = ConditionalDetrModel(config)
-
-        # Object detection heads
-        self.class_labels_classifier = nn.Linear(
-            config.d_model, config.num_labels
-        )  # We add one for the "no object" class
+        self.class_labels_classifier = nn.Linear(config.d_model, config.num_labels)
         self.bbox_predictor = ConditionalDetrMLPPredictionHead(
             input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
         )
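Side note on inverse_sigmoid above: it is the clamped logit function log(x / (1 - x)), used to map sigmoid-space reference points back to unbounded coordinates so that box deltas can be added before the sigmoid is re-applied. A quick numeric check, re-implementing the function for illustration:

import torch

def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

x = torch.tensor([-3.0, 0.0, 2.5])
roundtrip = inverse_sigmoid(torch.sigmoid(x))
print(torch.allclose(roundtrip, x, atol=1e-4))  # True (up to the eps clamping)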
@@ -1516,11 +1491,8 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    # taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/conditional_detr.py
-    def _set_aux_loss(self, outputs_class, outputs_coord):
-        return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
-
     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
@@ -1530,11 +1502,8 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
         inputs_embeds: torch.FloatTensor | None = None,
         decoder_inputs_embeds: torch.FloatTensor | None = None,
         labels: list[dict] | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
-    ) -> tuple[torch.FloatTensor] | ConditionalDetrObjectDetectionOutput:
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> ConditionalDetrObjectDetectionOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
             Not used by default. Can be used to mask object queries.
@@ -1584,8 +1553,6 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
         Detected remote with confidence 0.683 at location [334.48, 73.49, 366.37, 190.01]
         Detected couch with confidence 0.535 at location [0.52, 1.19, 640.35, 475.1]
         ```"""
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # First, sent images through CONDITIONAL_DETR base model to obtain encoder + decoder outputs
         outputs = self.model(
             pixel_values,
@@ -1594,9 +1561,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
             encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

         sequence_output = outputs[0]
@@ -1604,11 +1569,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
         # class logits + predicted bounding boxes
         logits = self.class_labels_classifier(sequence_output)

-
-        # are not specified, otherwise it will be another index which is hard to determine.
-        # Leave it as is, because it's not a common case to use
-        # return_dict=False + output_attentions=True / output_hidden_states=True
-        reference = outputs.reference_points if return_dict else outputs[-2]
+        reference = outputs.reference_points
         reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)

         hs = sequence_output
@@ -1622,7 +1583,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
         outputs_class, outputs_coord = None, None
         if self.config.auxiliary_loss:
             outputs_coords = []
-            intermediate = outputs.intermediate_hidden_states
+            intermediate = outputs.intermediate_hidden_states
             outputs_class = self.class_labels_classifier(intermediate)
             for lvl in range(intermediate.shape[0]):
                 tmp = self.bbox_predictor(intermediate[lvl])
@@ -1634,13 +1595,6 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
             logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
         )

-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes) + auxiliary_outputs + outputs
-            else:
-                output = (logits, pred_boxes) + outputs
-            return ((loss, loss_dict) + output) if loss is not None else output
-
         return ConditionalDetrObjectDetectionOutput(
             loss=loss,
             loss_dict=loss_dict,
@@ -1656,14 +1610,38 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
             encoder_attentions=outputs.encoder_attentions,
         )

+    # taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/conditional_detr.py
+    def _set_aux_loss(self, outputs_class, outputs_coord):
+        return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+

 @auto_docstring(
     custom_intro="""
-
-
+    CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
+    such as COCO panoptic.
     """
 )
 class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
+    _checkpoint_conversion_mapping = {
+        "bbox_attention.q_linear": "bbox_attention.q_proj",
+        "bbox_attention.k_linear": "bbox_attention.k_proj",
+        # Mask head refactor
+        "mask_head.lay1": "mask_head.conv1.conv",
+        "mask_head.gn1": "mask_head.conv1.norm",
+        "mask_head.lay2": "mask_head.conv2.conv",
+        "mask_head.gn2": "mask_head.conv2.norm",
+        "mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter",
+        "mask_head.lay3": "mask_head.fpn_stages.0.refine.conv",
+        "mask_head.gn3": "mask_head.fpn_stages.0.refine.norm",
+        "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter",
+        "mask_head.lay4": "mask_head.fpn_stages.1.refine.conv",
+        "mask_head.gn4": "mask_head.fpn_stages.1.refine.norm",
+        "mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter",
+        "mask_head.lay5": "mask_head.fpn_stages.2.refine.conv",
+        "mask_head.gn5": "mask_head.fpn_stages.2.refine.norm",
+        "mask_head.out_lay": "mask_head.output_conv",
+    }
+
     def __init__(self, config: ConditionalDetrConfig):
         super().__init__(config)

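Side note on the _checkpoint_conversion_mapping above: it pairs old parameter-name prefixes with their new module paths (e.g. mask_head.lay1 -> mask_head.conv1.conv). A minimal sketch of how such a prefix rename could be applied to a plain state dict; this is an illustration only and not the loader that transformers actually uses:

def rename_state_dict_keys(state_dict, mapping):
    """Replace old key prefixes with new ones, leaving unmatched keys untouched."""
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in mapping.items():
            if key.startswith(old_prefix):
                new_key = new_prefix + key[len(old_prefix):]
                break
        renamed[new_key] = value
    return renamed

mapping = {"mask_head.lay1": "mask_head.conv1.conv", "mask_head.gn1": "mask_head.conv1.norm"}
old_keys = {"mask_head.lay1.weight": 0, "mask_head.gn1.bias": 1, "bbox_predictor.layers.0.weight": 2}
print(rename_state_dict_keys(old_keys, mapping))
# {'mask_head.conv1.conv.weight': 0, 'mask_head.conv1.norm.bias': 1, 'bbox_predictor.layers.0.weight': 2}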
@@ -1672,20 +1650,21 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):

         # segmentation head
         hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
-        intermediate_channel_sizes = self.conditional_detr.model.backbone.conv_encoder.intermediate_channel_sizes
+        intermediate_channel_sizes = self.conditional_detr.model.backbone.intermediate_channel_sizes

         self.mask_head = ConditionalDetrMaskHeadSmallConv(
-            hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size
-        )
-
-        self.bbox_attention = ConditionalDetrMHAttentionMap(
-            hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
+            input_channels=hidden_size + number_of_heads,
+            fpn_channels=intermediate_channel_sizes[::-1][-3:],
+            hidden_size=hidden_size,
+            activation_function=config.activation_function,
         )

+        self.bbox_attention = ConditionalDetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0)
         # Initialize weights and apply final processing
         self.post_init()

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
@@ -1695,20 +1674,20 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
         inputs_embeds: torch.FloatTensor | None = None,
         decoder_inputs_embeds: torch.FloatTensor | None = None,
         labels: list[dict] | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | ConditionalDetrSegmentationOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
-            Not used by default. Can be used to mask object queries.
+            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for queries that are **not masked**,
+            - 0 for queries that are **masked**.
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
-            can choose to directly pass a flattened representation of an image.
+            Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires
+            multi-scale features from the backbone that are not available when bypassing it with inputs_embeds.
         decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
             Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
+            embedded representation. Useful for tasks that require custom query initialization.
         labels (`list[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
             dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
@@ -1721,26 +1700,21 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):

         ```python
         >>> import io
-        >>> import
+        >>> import httpx
+        >>> from io import BytesIO
         >>> from PIL import Image
         >>> import torch
         >>> import numpy

-        >>> from transformers import (
-        ...     AutoImageProcessor,
-        ...     ConditionalDetrConfig,
-        ...     ConditionalDetrForSegmentation,
-        ... )
+        >>> from transformers import AutoImageProcessor, ConditionalDetrForSegmentation
         >>> from transformers.image_transforms import rgb_to_id

         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>>
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))

-        >>> image_processor = AutoImageProcessor.from_pretrained("
-
-        >>> # randomly initialize all weights of the model
-        >>> config = ConditionalDetrConfig()
-        >>> model = ConditionalDetrForSegmentation(config)
+        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/conditional_detr-resnet-50-panoptic")
+        >>> model = ConditionalDetrForSegmentation.from_pretrained("facebook/conditional_detr-resnet-50-panoptic")

         >>> # prepare image for the model
         >>> inputs = image_processor(images=image, return_tensors="pt")
@@ -1751,89 +1725,88 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
         >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
         >>> # Segmentation results are returned as a list of dictionaries
         >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])
+
         >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
         >>> panoptic_seg = result[0]["segmentation"]
+        >>> panoptic_seg.shape
+        torch.Size([300, 500])
         >>> # Get prediction score and segment_id to class_id mapping of each segment
         >>> panoptic_segments_info = result[0]["segments_info"]
+        >>> len(panoptic_segments_info)
+        5
         ```"""

-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         batch_size, num_channels, height, width = pixel_values.shape
         device = pixel_values.device

         if pixel_mask is None:
             pixel_mask = torch.ones((batch_size, height, width), device=device)

-
-        features, object_queries_list = self.conditional_detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
+        vision_features = self.conditional_detr.model.backbone(pixel_values, pixel_mask)
+        feature_map, mask = vision_features[-1]

-        # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
-        feature_map, mask = features[-1]
-        batch_size, num_channels, height, width = feature_map.shape
+        # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
         projected_feature_map = self.conditional_detr.model.input_projection(feature_map)
-
-        # Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
-        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
         flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
-        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
-
+        spatial_position_embeddings = self.conditional_detr.model.position_embedding(
+            shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
+        )
         flattened_mask = mask.flatten(1)

-        # Fourth, sent flattened_features + flattened_mask + object_queries through encoder
-        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
-        # flattened_mask is a Tensor of shape (batch_size, height*width)
         if encoder_outputs is None:
             encoder_outputs = self.conditional_detr.model.encoder(
                 inputs_embeds=flattened_features,
                 attention_mask=flattened_mask,
-                object_queries=object_queries,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+                spatial_position_embeddings=spatial_position_embeddings,
+                **kwargs,
             )

-        # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
-        query_position_embeddings = self.conditional_detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
-            batch_size, 1, 1
-        )
-        queries = torch.zeros_like(query_position_embeddings)
+        object_queries_position_embeddings = self.conditional_detr.model.query_position_embeddings.weight.unsqueeze(
+            0
+        ).repeat(batch_size, 1, 1)
+
+        # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
+        if decoder_inputs_embeds is not None:
+            queries = decoder_inputs_embeds
+        else:
+            queries = torch.zeros_like(object_queries_position_embeddings)

-        # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
         decoder_outputs = self.conditional_detr.model.decoder(
             inputs_embeds=queries,
-            attention_mask=None,
-            object_queries=object_queries,
-            query_position_embeddings=query_position_embeddings,
-            encoder_hidden_states=encoder_outputs[0],
+            attention_mask=decoder_attention_mask,
+            spatial_position_embeddings=spatial_position_embeddings,
+            object_queries_position_embeddings=object_queries_position_embeddings,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
             encoder_attention_mask=flattened_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

         sequence_output = decoder_outputs[0]

-        # Sixth, compute logits, pred_boxes and pred_masks
         logits = self.conditional_detr.class_labels_classifier(sequence_output)
         pred_boxes = self.conditional_detr.bbox_predictor(sequence_output).sigmoid()

-        memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width)
-        mask = flattened_mask.view(batch_size, height, width)
+        height, width = feature_map.shape[-2:]
+        memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view(
+            batch_size, self.config.d_model, height, width
+        )
+        attention_mask = flattened_mask.view(batch_size, height, width)
+
+        if attention_mask is not None:
+            min_dtype = torch.finfo(memory.dtype).min
+            attention_mask = torch.where(
+                attention_mask.unsqueeze(1).unsqueeze(1),
+                torch.tensor(0.0, device=memory.device, dtype=memory.dtype),
+                min_dtype,
+            )

-
-        # important: we need to reverse the mask, since in the original implementation the mask works reversed
-        # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32)
-        bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask)
+        bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=attention_mask)

-        seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]])
+        seg_masks = self.mask_head(
+            features=projected_feature_map,
+            attention_masks=bbox_mask,
+            fpn_features=[vision_features[2][0], vision_features[1][0], vision_features[0][0]],
+        )

         pred_masks = seg_masks.view(
             batch_size, self.conditional_detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]
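Side note on the torch.where construction above: it turns a boolean padding mask into an additive mask filled with the dtype's most negative value, so that masked positions effectively vanish after the softmax inside the attention map. A minimal sketch with a toy mask (illustrative only):

import torch

scores = torch.randn(1, 1, 2, 4)                    # (batch, heads, queries, key positions)
keep = torch.tensor([[True, True, False, False]])   # last two key positions are padding

min_value = torch.finfo(scores.dtype).min
additive = torch.where(keep.unsqueeze(1).unsqueeze(1), torch.tensor(0.0), min_value)

weights = torch.softmax(scores + additive, dim=-1)
print(weights[..., 2:].abs().max())  # ~0: padded positions receive (near) zero attention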
@@ -1843,20 +1816,13 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
         if labels is not None:
             outputs_class, outputs_coord = None, None
             if self.config.auxiliary_loss:
-                intermediate = decoder_outputs.intermediate_hidden_states
+                intermediate = decoder_outputs.intermediate_hidden_states
                 outputs_class = self.conditional_detr.class_labels_classifier(intermediate)
                 outputs_coord = self.conditional_detr.bbox_predictor(intermediate).sigmoid()
             loss, loss_dict, auxiliary_outputs = self.loss_function(
-                logits, labels,
+                logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
             )
 
-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
-            else:
-                output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
-            return ((loss, loss_dict) + output) if loss is not None else output
-
         return ConditionalDetrSegmentationOutput(
             loss=loss,
             loss_dict=loss_dict,
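
The loss branch in this hunk only runs when labels are passed. In the DETR family these are supplied as one annotation dict per image; the sketch below shows the commonly expected structure (key names and shapes are stated as an assumption based on the DETR-style loss, not read from this diff):

    import torch

    # One dict per image in the batch; boxes are normalized (center_x, center_y, width, height).
    labels = [
        {
            "class_labels": torch.tensor([3, 17]),  # (num_target_boxes,)
            "boxes": torch.tensor(
                [[0.50, 0.50, 0.20, 0.30], [0.30, 0.40, 0.10, 0.10]]
            ),  # (num_target_boxes, 4)
            "masks": torch.zeros(2, 256, 256),  # per-instance binary masks
        }
    ]
    # outputs = model(pixel_values=pixel_values, labels=labels)  # populates outputs.loss and outputs.loss_dict
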
@@ -1874,120 +1840,6 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
         )
 
 
-def _expand(tensor, length: int):
-    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
-
-
-# Copied from transformers.models.detr.modeling_detr.DetrMaskHeadSmallConv with Detr->ConditionalDetr
-class ConditionalDetrMaskHeadSmallConv(nn.Module):
-    """
-    Simple convolutional head, using group norm. Upsampling is done using a FPN approach
-    """
-
-    def __init__(self, dim, fpn_dims, context_dim):
-        super().__init__()
-
-        if dim % 8 != 0:
-            raise ValueError(
-                "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
-                " GroupNorm is set to 8"
-            )
-
-        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
-
-        self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
-        self.gn1 = nn.GroupNorm(8, dim)
-        self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
-        self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
-        self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
-        self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
-        self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
-        self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
-        self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
-        self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
-        self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
-
-        self.dim = dim
-
-        self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
-        self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
-        self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
-
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                init.kaiming_uniform_(m.weight, a=1)
-                init.constant_(m.bias, 0)
-
-    def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]):
-        # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
-        # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
-        # We expand the projected feature map to match the number of heads.
-        x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
-
-        x = self.lay1(x)
-        x = self.gn1(x)
-        x = nn.functional.relu(x)
-        x = self.lay2(x)
-        x = self.gn2(x)
-        x = nn.functional.relu(x)
-
-        cur_fpn = self.adapter1(fpns[0])
-        if cur_fpn.size(0) != x.size(0):
-            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
-        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
-        x = self.lay3(x)
-        x = self.gn3(x)
-        x = nn.functional.relu(x)
-
-        cur_fpn = self.adapter2(fpns[1])
-        if cur_fpn.size(0) != x.size(0):
-            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
-        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
-        x = self.lay4(x)
-        x = self.gn4(x)
-        x = nn.functional.relu(x)
-
-        cur_fpn = self.adapter3(fpns[2])
-        if cur_fpn.size(0) != x.size(0):
-            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
-        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
-        x = self.lay5(x)
-        x = self.gn5(x)
-        x = nn.functional.relu(x)
-
-        x = self.out_lay(x)
-        return x
-
-
-# Copied from transformers.models.detr.modeling_detr.DetrMHAttentionMap with Detr->ConditionalDetr
-class ConditionalDetrMHAttentionMap(nn.Module):
-    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
-
-    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
-        super().__init__()
-        self.num_heads = num_heads
-        self.hidden_dim = hidden_dim
-        self.dropout = nn.Dropout(dropout)
-
-        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
-        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
-
-        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
-
-    def forward(self, q, k, mask: Tensor | None = None):
-        q = self.q_linear(q)
-        k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
-        queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
-        keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
-        weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
-
-        if mask is not None:
-            weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
-        weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
-        weights = self.dropout(weights)
-        return weights
-
-
 __all__ = [
     "ConditionalDetrForObjectDetection",
     "ConditionalDetrForSegmentation",