transformers-5.0.0rc3-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
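A per-file summary like the one below can be reproduced locally. The following is a minimal sketch, assuming the two wheels have already been downloaded (for example with `pip download transformers==5.1.0 --no-deps`); the local filenames are illustrative, not authoritative.

```python
# Sketch: compare two locally downloaded wheels and print a
# per-file "+added -removed" summary. Wheels are zip archives,
# so the standard library is enough.
import difflib
import zipfile

OLD_WHEEL = "transformers-5.0.0rc3-py3-none-any.whl"  # assumed local path
NEW_WHEEL = "transformers-5.1.0-py3-none-any.whl"     # assumed local path


def read_members(path):
    """Map archive member name -> decoded text lines for .py files."""
    with zipfile.ZipFile(path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace").splitlines()
            for name in zf.namelist()
            if name.endswith(".py")
        }


old, new = read_members(OLD_WHEEL), read_members(NEW_WHEEL)
for name in sorted(set(old) | set(new)):
    diff = difflib.unified_diff(old.get(name, []), new.get(name, []), lineterm="")
    added = removed = 0
    for line in diff:
        if line.startswith("+") and not line.startswith("+++"):
            added += 1
        elif line.startswith("-") and not line.startswith("---"):
            removed += 1
    if added or removed:
        print(f"{name} +{added} -{removed}")
```

Exact counts may differ slightly from the registry's view depending on how renames and non-Python files are handled.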
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
transformers/models/detr/modeling_detr.py

@@ -14,32 +14,35 @@
 """PyTorch DETR model."""

 import math
+from collections.abc import Callable
 from dataclasses import dataclass

 import torch
-
+import torch.nn as nn

 from ... import initialization as init
 from ...activations import ACT2FN
-from ...
+from ...backbone_utils import load_backbone
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import
-
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithCrossAttentions,
+    Seq2SeqModelOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import (
     ModelOutput,
+    TransformersKwargs,
     auto_docstring,
-    is_timm_available,
     logging,
-    requires_backends,
 )
-from ...utils.
+from ...utils.generic import can_return_tuple, check_model_inputs
 from .configuration_detr import DetrConfig


-if is_timm_available():
-    from timm import create_model
-
-
 logger = logging.get_logger(__name__)

@@ -178,8 +181,6 @@ class DetrSegmentationOutput(ModelOutput):
     encoder_attentions: tuple[torch.FloatTensor] | None = None


-# BELOW: utilities copied from
-# https://github.com/facebookresearch/detr/blob/master/backbone.py
 class DetrFrozenBatchNorm2d(nn.Module):
     """
     BatchNorm2d where the batch statistics and the affine parameters are fixed.
@@ -256,47 +257,25 @@ class DetrConvEncoder(nn.Module):

         self.config = config

-
-
-            # We default to values which were previously hard-coded. This enables configurability from the config
-            # using backbone arguments, while keeping the default behavior the same.
-            requires_backends(self, ["timm"])
-            kwargs = getattr(config, "backbone_kwargs", {})
-            kwargs = {} if kwargs is None else kwargs.copy()
-            out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
-            num_channels = kwargs.pop("in_chans", config.num_channels)
-            if config.dilation:
-                kwargs["output_stride"] = kwargs.get("output_stride", 16)
-            backbone = create_model(
-                config.backbone,
-                pretrained=config.use_pretrained_backbone,
-                features_only=True,
-                out_indices=out_indices,
-                in_chans=num_channels,
-                **kwargs,
-            )
-        else:
-            backbone = load_backbone(config)
+        backbone = load_backbone(config)
+        self.intermediate_channel_sizes = backbone.channels

         # replace batch norm by frozen batch norm
         with torch.no_grad():
             replace_batch_norm(backbone)
-        self.model = backbone
-        self.intermediate_channel_sizes = (
-            self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
-        )

-
-
-
-
-
-
-
+        # We used to load with timm library directly instead of the AutoBackbone API
+        # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
+        is_timm_model = False
+        if hasattr(backbone, "_backbone"):
+            backbone = backbone._backbone
+            is_timm_model = True
+        self.model = backbone

+        backbone_model_type = config.backbone_config.model_type
         if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
-                if
+                if is_timm_model:
                     if "layer2" not in name and "layer3" not in name and "layer4" not in name:
                         parameter.requires_grad_(False)
                 else:

@@ -305,7 +284,9 @@ class DetrConvEncoder(nn.Module):

     def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
         # send pixel_values through the model to get list of feature maps
-        features = self.model(pixel_values)
+        features = self.model(pixel_values)
+        if isinstance(features, dict):
+            features = features.feature_maps

         out = []
         for feature_map in features:
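The two hunks above replace the timm-specific branch in `DetrConvEncoder` with a single `load_backbone(config)` call, unwrap a timm checkpoint's `_backbone` attribute so existing weights still load, and normalize the forward output to a plain list of feature maps. The following is only a minimal sketch of that consumption pattern, not the library's code: the dummy backbone class, its shapes, and everything other than the `_backbone`, `channels`, and `feature_maps` names taken from the diff are assumptions for illustration.

```python
import torch
import torch.nn as nn


class _DummyBackbone(nn.Module):
    """Stand-in for the object returned by load_backbone(config); real backbones expose `.channels`."""

    channels = [256, 512]

    def forward(self, pixel_values):
        # A dict-like output; transformers backbone outputs also expose `.feature_maps` as an attribute.
        return {"feature_maps": [torch.randn(1, c, 32, 32) for c in self.channels]}


backbone = _DummyBackbone()

# Unwrap a timm-style wrapper if present, mirroring the `hasattr(backbone, "_backbone")` check above.
is_timm_model = hasattr(backbone, "_backbone")
model = backbone._backbone if is_timm_model else backbone

features = model(torch.randn(1, 3, 256, 256))
if isinstance(features, dict):
    features = features["feature_maps"]  # plain dict here; the diffed code uses `.feature_maps`
print([tuple(f.shape) for f in features])
```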
@@ -315,61 +296,55 @@ class DetrConvEncoder(nn.Module):
         return out


-class DetrConvModel(nn.Module):
-    """
-    This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
-    """
-
-    def __init__(self, conv_encoder, position_embedding):
-        super().__init__()
-        self.conv_encoder = conv_encoder
-        self.position_embedding = position_embedding
-
-    def forward(self, pixel_values, pixel_mask):
-        # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
-        out = self.conv_encoder(pixel_values, pixel_mask)
-        pos = []
-        for feature_map, mask in out:
-            # position encoding
-            pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
-
-        return out, pos
-
-
 class DetrSinePositionEmbedding(nn.Module):
     """
     This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
     need paper, generalized to work on images.
     """

-    def __init__(
+    def __init__(
+        self,
+        num_position_features: int = 64,
+        temperature: int = 10000,
+        normalize: bool = False,
+        scale: float | None = None,
+    ):
         super().__init__()
-        self.embedding_dim = embedding_dim
-        self.temperature = temperature
-        self.normalize = normalize
         if scale is not None and normalize is False:
             raise ValueError("normalize should be True if scale is passed")
-
-
-        self.
+        self.num_position_features = num_position_features
+        self.temperature = temperature
+        self.normalize = normalize
+        self.scale = 2 * math.pi if scale is None else scale

-
-
-
-
-
+    @compile_compatible_method_lru_cache(maxsize=1)
+    def forward(
+        self,
+        shape: torch.Size,
+        device: torch.device | str,
+        dtype: torch.dtype,
+        mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if mask is None:
+            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+        y_embed = mask.cumsum(1, dtype=dtype)
+        x_embed = mask.cumsum(2, dtype=dtype)
         if self.normalize:
-
-
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

-        dim_t = torch.arange(self.
-        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.
+        dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
+        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)

         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+        # expected by the encoder
+        pos = pos.flatten(2).permute(0, 2, 1)
         return pos

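For reference, here is a self-contained sketch of the 2D sine position embedding the rewritten `DetrSinePositionEmbedding.forward` computes: cumulative sums over a valid-pixel mask give each row and column its index, the indices are normalized to [0, 2π], and sin/cos at geometrically spaced frequencies are interleaved before flattening to `(batch, height * width, hidden)`. This mirrors the math above under the assumption that normalization is enabled; it is not the library class itself.

```python
import math
import torch


def sine_position_embedding(mask: torch.Tensor, num_position_features: int = 64, temperature: int = 10000):
    # mask: (batch, height, width), 1.0 for valid pixels
    y_embed = mask.cumsum(1, dtype=torch.float32)
    x_embed = mask.cumsum(2, dtype=torch.float32)

    # normalize indices to [0, 2*pi]
    eps = 1e-6
    y_embed = y_embed / (y_embed[:, -1:, :] + eps) * 2 * math.pi
    x_embed = x_embed / (x_embed[:, :, -1:] + eps) * 2 * math.pi

    # geometrically spaced frequencies, shared between the sin and cos channels
    dim_t = torch.arange(num_position_features, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_position_features)

    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)

    # (batch, 2 * num_position_features, height, width) -> (batch, height * width, hidden)
    pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
    return pos.flatten(2).permute(0, 2, 1)


mask = torch.ones(1, 16, 16)
print(sine_position_embedding(mask).shape)  # torch.Size([1, 256, 128])
```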
@@ -383,207 +358,260 @@ class DetrLearnedPositionEmbedding(nn.Module):
         self.row_embeddings = nn.Embedding(50, embedding_dim)
         self.column_embeddings = nn.Embedding(50, embedding_dim)

-
-
-
-
+    @compile_compatible_method_lru_cache(maxsize=1)
+    def forward(
+        self,
+        shape: torch.Size,
+        device: torch.device | str,
+        dtype: torch.dtype,
+        mask: torch.Tensor | None = None,
+    ):
+        height, width = shape[-2:]
+        width_values = torch.arange(width, device=device)
+        height_values = torch.arange(height, device=device)
         x_emb = self.column_embeddings(width_values)
         y_emb = self.row_embeddings(height_values)
         pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
         pos = pos.permute(2, 0, 1)
         pos = pos.unsqueeze(0)
-        pos = pos.repeat(
+        pos = pos.repeat(shape[0], 1, 1, 1)
+        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+        # expected by the encoder
+        pos = pos.flatten(2).permute(0, 2, 1)
         return pos


-
-
-
-
-
-
-
-
-
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor | None,
+    scaling: float | None = None,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

-
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()

+    return attn_output, attn_weights

-
+
+class DetrSelfAttention(nn.Module):
     """
-    Multi-headed attention from 'Attention Is All You Need' paper.
+    Multi-headed self-attention from 'Attention Is All You Need' paper.

-
+    In DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
     """

     def __init__(
         self,
-
-
+        config: DetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
         dropout: float = 0.0,
         bias: bool = True,
     ):
         super().__init__()
-        self.
-        self.
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        if self.head_dim * num_heads != self.embed_dim:
-            raise ValueError(
-                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {num_heads})."
-            )
+        self.config = config
+        self.head_dim = hidden_size // num_attention_heads
         self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+        self.is_causal = False

-        self.k_proj = nn.Linear(
-        self.v_proj = nn.Linear(
-        self.q_proj = nn.Linear(
-        self.
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
-        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def with_pos_embed(self, tensor: torch.Tensor, object_queries: Tensor | None):
-        return tensor if object_queries is None else tensor + object_queries
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
-
-
-
-
-
-        """
-
-
-        is_cross_attention = key_value_states is not None
-        batch_size, target_len, embed_dim = hidden_states.size()
-
-        # add position embeddings to the hidden states before projecting to queries and keys
-        if object_queries is not None:
-            hidden_states_original = hidden_states
-            hidden_states = self.with_pos_embed(hidden_states, object_queries)
-
-        # add key-value position embeddings to the key value states
-        if spatial_position_embeddings is not None:
-            key_value_states_original = key_value_states
-            key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
-
-        # get query proj
-        query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        if is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
-            value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
-            value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
+        position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Position embeddings are added to both queries and keys (but not values).
+        """
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)

-
-        query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states

-
+        query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

-
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )

-
-
-
-
-
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )

-
-
-
-                f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
-                f" {attention_mask.size()}"
-            )
-            if attention_mask.dtype == torch.bool:
-                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
-                    attention_mask, -torch.inf
-                )
-            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
-            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
-            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
-        else:
-            attn_weights_reshaped = None
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights

-        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

-
+class DetrCrossAttention(nn.Module):
+    """
+    Multi-headed cross-attention from 'Attention Is All You Need' paper.

-
-
-
-
-
+    In DETR, queries get their own position embeddings, while keys get encoder position embeddings.
+    Values don't get any position embeddings.
+    """
+
+    def __init__(
+        self,
+        config: DetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.config = config
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+        self.is_causal = False
+
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        position_embeddings: torch.Tensor | None = None,
+        encoder_position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Position embeddings logic:
+        - Queries get position_embeddings
+        - Keys get encoder_position_embeddings
+        - Values don't get any position embeddings
+        """
+        query_input_shape = hidden_states.shape[:-1]
+        query_hidden_shape = (*query_input_shape, -1, self.head_dim)

-
-
-        attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
+        kv_input_shape = key_value_states.shape[:-1]
+        kv_hidden_shape = (*kv_input_shape, -1, self.head_dim)

-
+        query_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
+        key_input = (
+            key_value_states + encoder_position_embeddings
+            if encoder_position_embeddings is not None
+            else key_value_states
+        )
+
+        query_states = self.q_proj(query_input).view(query_hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(key_input).view(kv_hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(key_value_states).view(kv_hidden_shape).transpose(1, 2)
+
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class DetrMLP(nn.Module):
+    def __init__(self, config: DetrConfig, hidden_size: int, intermediate_size: int):
+        super().__init__()
+        self.fc1 = nn.Linear(hidden_size, intermediate_size)
+        self.fc2 = nn.Linear(intermediate_size, hidden_size)
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.dropout = config.dropout

-
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        return hidden_states


-class DetrEncoderLayer(
+class DetrEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DetrConfig):
         super().__init__()
-        self.
-        self.self_attn =
-
-
+        self.hidden_size = config.d_model
+        self.self_attn = DetrSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
         )
-        self.self_attn_layer_norm = nn.LayerNorm(self.
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
         self.dropout = config.dropout
-        self.
-        self.
-        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
-        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.mlp = DetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-
-
-    ):
+        spatial_position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len,
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-
-
-
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
+            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
+                Spatial position embeddings (2D positional encodings of image locations), to be added to both
+                the queries and keys in self-attention (but not to values).
         """
         residual = hidden_states
-        hidden_states,
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-
-
+            position_embeddings=spatial_position_embeddings,
+            **kwargs,
         )

         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
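The hunk above splits DETR attention into `DetrSelfAttention` and `DetrCrossAttention` built around a pluggable attention interface, with the DETR-specific twist that position embeddings are added to queries and keys before projection while values stay position-free. A minimal toy version of that pattern (eager path only; the function and variable names here are illustrative assumptions, not the library's API) looks like this:

```python
import math
import torch


def toy_detr_self_attention(hidden_states, position_embeddings, q_proj, k_proj, v_proj, num_heads):
    batch, seq_len, hidden = hidden_states.shape
    head_dim = hidden // num_heads

    # queries and keys see the position embeddings, values do not
    query_key_input = hidden_states + position_embeddings

    def split_heads(x):
        return x.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)

    q = split_heads(q_proj(query_key_input))
    k = split_heads(k_proj(query_key_input))
    v = split_heads(v_proj(hidden_states))

    # scaled dot-product attention, eager style: scale, softmax, weighted sum
    scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    weights = scores.softmax(dim=-1)
    out = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq_len, hidden)
    return out, weights


hidden = torch.randn(2, 10, 64)
pos = torch.randn(2, 10, 64)
make_proj = lambda: torch.nn.Linear(64, 64)
out, _ = toy_detr_self_attention(hidden, pos, make_proj(), make_proj(), make_proj(), num_heads=8)
print(out.shape)  # torch.Size([2, 10, 64])
```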
@@ -591,12 +619,7 @@ class DetrEncoderLayer(nn.Module):
         hidden_states = self.self_attn_layer_norm(hidden_states)

         residual = hidden_states
-        hidden_states = self.
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
+        hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)

@@ -605,78 +628,69 @@ class DetrEncoderLayer(nn.Module):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

-
-
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs
+        return hidden_states


 class DetrDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DetrConfig):
         super().__init__()
-        self.
+        self.hidden_size = config.d_model

-        self.self_attn =
-
-
+        self.self_attn = DetrSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.decoder_attention_heads,
             dropout=config.attention_dropout,
         )
         self.dropout = config.dropout
-        self.activation_fn = ACT2FN[config.activation_function]
-        self.activation_dropout = config.activation_dropout

-        self.self_attn_layer_norm = nn.LayerNorm(self.
-        self.encoder_attn =
-
-
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+        self.encoder_attn = DetrCrossAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.decoder_attention_heads,
             dropout=config.attention_dropout,
         )
-        self.encoder_attn_layer_norm = nn.LayerNorm(self.
-        self.
-        self.
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+        self.mlp = DetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
-
-
+        spatial_position_embeddings: torch.Tensor | None = None,
+        object_queries_position_embeddings: torch.Tensor | None = None,
         encoder_hidden_states: torch.Tensor | None = None,
         encoder_attention_mask: torch.Tensor | None = None,
-
-    ):
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len,
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-
-
-
-
-
-
+            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
+                Spatial position embeddings (2D positional encodings from encoder) that are added to the keys only
+                in the cross-attention layer (not to values).
+            object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
+                Position embeddings for the object query slots. In self-attention, these are added to both queries
+                and keys (not values). In cross-attention, these are added to queries only (not to keys or values).
             encoder_hidden_states (`torch.FloatTensor`):
-                cross attention input to the layer of shape `(batch, seq_len,
+                cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
             encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
         """
         residual = hidden_states

         # Self Attention
-        hidden_states,
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
-
+            position_embeddings=object_queries_position_embeddings,
             attention_mask=attention_mask,
-
+            **kwargs,
         )

         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -684,17 +698,16 @@ class DetrDecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.self_attn_layer_norm(hidden_states)

         # Cross-Attention Block
-        cross_attn_weights = None
         if encoder_hidden_states is not None:
             residual = hidden_states

-            hidden_states,
+            hidden_states, _ = self.encoder_attn(
                 hidden_states=hidden_states,
-                object_queries=query_position_embeddings,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
-
-
+                position_embeddings=object_queries_position_embeddings,
+                encoder_position_embeddings=spatial_position_embeddings,
+                **kwargs,
             )

             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -703,19 +716,164 @@ class DetrDecoderLayer(GradientCheckpointingLayer):

         # Fully Connected
         residual = hidden_states
-        hidden_states = self.
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)

-
+        return hidden_states
+
+
+class DetrConvBlock(nn.Module):
+    """Basic conv block: Conv3x3 -> GroupNorm -> Activation."""
+
+    def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+        self.norm = nn.GroupNorm(min(8, out_channels), out_channels)
+        self.activation = ACT2FN[activation]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.activation(self.norm(self.conv(x)))
+
+
+class DetrFPNFusionStage(nn.Module):
+    """Single FPN fusion stage combining low-resolution features with high-resolution FPN features."""
+
+    def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"):
+        super().__init__()
+        self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
+        self.refine = DetrConvBlock(current_channels, output_channels, activation)
+
+    def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in)
+            fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out)
+
+        Returns:
+            Fused and refined features, shape (B*Q, output_channels, H_out, W_out)
+        """
+        fpn_features = self.fpn_adapter(fpn_features)
+        features = nn.functional.interpolate(features, size=fpn_features.shape[-2:], mode="nearest")
+        return self.refine(fpn_features + features)
+

-
-
+class DetrMaskHeadSmallConv(nn.Module):
+    """
+    Segmentation mask head that generates per-query masks using FPN-based progressive upsampling.

-
+    Combines attention maps (spatial localization) with encoder features (semantics) and progressively
+    upsamples through multiple scales, fusing with FPN features for high-resolution detail.
+    """
+
+    def __init__(
+        self,
+        input_channels: int,
+        fpn_channels: list[int],
+        hidden_size: int,
+        activation_function: str = "relu",
+    ):
+        super().__init__()
+        if input_channels % 8 != 0:
+            raise ValueError(f"input_channels must be divisible by 8, got {input_channels}")
+
+        self.conv1 = DetrConvBlock(input_channels, input_channels, activation_function)
+        self.conv2 = DetrConvBlock(input_channels, hidden_size // 2, activation_function)
+
+        # Progressive channel reduction: /2 -> /4 -> /8 -> /16
+        self.fpn_stages = nn.ModuleList(
+            [
+                DetrFPNFusionStage(fpn_channels[0], hidden_size // 2, hidden_size // 4, activation_function),
+                DetrFPNFusionStage(fpn_channels[1], hidden_size // 4, hidden_size // 8, activation_function),
+                DetrFPNFusionStage(fpn_channels[2], hidden_size // 8, hidden_size // 16, activation_function),
+            ]
+        )
+
+        self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1)
+
+    def forward(
+        self,
+        features: torch.Tensor,
+        attention_masks: torch.Tensor,
+        fpn_features: list[torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Args:
+            features: Encoder output features, shape (batch_size, hidden_size, H, W)
+            attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W)
+            fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, W)

+        Returns:
+            Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W)
+        """
+        num_queries = attention_masks.shape[1]
+
+        # Expand to (batch_size * num_queries) dimension
+        features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)
+        attention_masks = attention_masks.flatten(0, 1)
+        fpn_features = [
+            fpn_feat.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) for fpn_feat in fpn_features
+        ]
+
+        hidden_states = torch.cat([features, attention_masks], dim=1)
+        hidden_states = self.conv1(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        for fpn_stage, fpn_feat in zip(self.fpn_stages, fpn_features):
+            hidden_states = fpn_stage(hidden_states, fpn_feat)
+
+        return self.output_conv(hidden_states)
+
+
+class DetrMHAttentionMap(nn.Module):
+    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+    def forward(
+        self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None
+    ):
+        query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim)
+        key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:])
+
+        query_states = self.q_proj(query_states).view(query_hidden_shape)
+        key_states = nn.functional.conv2d(
+            key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias
+        ).view(key_hidden_shape)
+
+        batch_size, num_queries, num_heads, head_dim = query_states.shape
+        _, _, _, height, width = key_states.shape
+        query_shape = (batch_size * num_heads, num_queries, head_dim)
+        key_shape = (batch_size * num_heads, height * width, head_dim)
+        attn_weights_shape = (batch_size, num_heads, num_queries, height, width)
+
+        query = query_states.transpose(1, 2).contiguous().view(query_shape)
+        key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape)
+
+        attn_weights = (
+            (torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2)
+        )
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size())
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+        return attn_weights


 @auto_docstring
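The new segmentation mask head above is built from small conv blocks and `DetrFPNFusionStage` modules that repeatedly upsample coarse per-query features and fuse them with higher-resolution backbone features. A stripped-down sketch of one fusion step, with illustrative channel counts and a plain `nn.Sequential` standing in for the diff's `DetrConvBlock`, could look like this:

```python
import torch
import torch.nn as nn


class ToyFusionStage(nn.Module):
    def __init__(self, fpn_channels: int, current_channels: int, output_channels: int):
        super().__init__()
        # 1x1 conv adapts the FPN feature's channel count to the coarse feature's channels
        self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
        # 3x3 conv + GroupNorm + ReLU refines the fused result
        self.refine = nn.Sequential(
            nn.Conv2d(current_channels, output_channels, kernel_size=3, padding=1),
            nn.GroupNorm(min(8, output_channels), output_channels),
            nn.ReLU(),
        )

    def forward(self, coarse: torch.Tensor, fpn: torch.Tensor) -> torch.Tensor:
        fpn = self.fpn_adapter(fpn)
        # nearest-neighbor upsampling brings the coarse map to the FPN resolution before summing
        coarse = nn.functional.interpolate(coarse, size=fpn.shape[-2:], mode="nearest")
        return self.refine(fpn + coarse)


stage = ToyFusionStage(fpn_channels=512, current_channels=128, output_channels=64)
coarse = torch.randn(2, 128, 8, 8)   # low-resolution features being upsampled
fpn = torch.randn(2, 512, 16, 16)    # higher-resolution backbone feature
print(stage(coarse, fpn).shape)      # torch.Size([2, 64, 16, 16])
```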
@@ -725,21 +883,36 @@ class DetrPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     input_modalities = ("image",)
     _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
+    supports_gradient_checkpointing = True
+    _supports_sdpa = True
+    _supports_flash_attn = True
+    _supports_attention_backend = True
+    _supports_flex_attn = True  # Uses create_bidirectional_masks for attention masking
+    _keys_to_ignore_on_load_unexpected = [
+        r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
+    ]

     @torch.no_grad()
     def _init_weights(self, module):
         std = self.config.init_std
         xavier_std = self.config.init_xavier_std

-        if isinstance(module,
-
-
-
-
+        if isinstance(module, DetrMaskHeadSmallConv):
+            # DetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers
+            for m in module.modules():
+                if isinstance(m, nn.Conv2d):
+                    init.kaiming_uniform_(m.weight, a=1)
+                    if m.bias is not None:
+                        init.constant_(m.bias, 0)
+        elif isinstance(module, DetrMHAttentionMap):
+            init.zeros_(module.k_proj.bias)
+            init.zeros_(module.q_proj.bias)
+            init.xavier_uniform_(module.k_proj.weight, gain=xavier_std)
+            init.xavier_uniform_(module.q_proj.weight, gain=xavier_std)
         elif isinstance(module, DetrLearnedPositionEmbedding):
             init.uniform_(module.row_embeddings.weight)
             init.uniform_(module.column_embeddings.weight)
-
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
             init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
                 init.zeros_(module.bias)
@@ -755,47 +928,36 @@ class DetrPreTrainedModel(PreTrainedModel):

 class DetrEncoder(DetrPreTrainedModel):
     """
-    Transformer encoder
-    [`DetrEncoderLayer`].
-
-    The encoder updates the flattened feature map through multiple self-attention layers.
-
-    Small tweak for DETR:
-
-    - object_queries are added to the forward pass.
+    Transformer encoder that processes a flattened feature map from a vision backbone, composed of a stack of
+    [`DetrEncoderLayer`] modules.

     Args:
-        config:
+        config (`DetrConfig`): Model configuration object.
     """

+    _can_record_outputs = {"hidden_states": DetrEncoderLayer, "attentions": DetrSelfAttention}
+
     def __init__(self, config: DetrConfig):
         super().__init__(config)

         self.dropout = config.dropout
-        self.layerdrop = config.encoder_layerdrop
-
         self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)])

-        # in the original DETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default
-
         # Initialize weights and apply final processing
         self.post_init()

+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         attention_mask=None,
-
-
-
-        return_dict=None,
-        **kwargs,
-    ):
+        spatial_position_embeddings=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
-
             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:

@@ -803,112 +965,67 @@ class DetrEncoder(DetrPreTrainedModel):
|
|
|
803
965
|
- 0 for pixel features that are padding (i.e. **masked**).
|
|
804
966
|
|
|
805
967
|
[What are attention masks?](../glossary#attention-mask)
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
Object queries that are added to the queries in each self-attention layer.
|
|
809
|
-
|
|
810
|
-
output_attentions (`bool`, *optional*):
|
|
811
|
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
812
|
-
returned tensors for more detail.
|
|
813
|
-
output_hidden_states (`bool`, *optional*):
|
|
814
|
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
|
815
|
-
for more detail.
|
|
816
|
-
return_dict (`bool`, *optional*):
|
|
817
|
-
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
968
|
+
spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
|
969
|
+
Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
|
|
818
970
|
"""
|
|
819
|
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
820
|
-
output_hidden_states = (
|
|
821
|
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
822
|
-
)
|
|
823
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
824
|
-
|
|
825
971
|
hidden_states = inputs_embeds
|
|
826
972
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
827
973
|
|
|
828
974
|
# expand attention_mask
|
|
829
975
|
if attention_mask is not None:
|
|
830
976
|
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
|
|
831
|
-
attention_mask =
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
if output_hidden_states:
|
|
837
|
-
encoder_states = encoder_states + (hidden_states,)
|
|
838
|
-
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
|
|
839
|
-
to_drop = False
|
|
840
|
-
if self.training:
|
|
841
|
-
dropout_probability = torch.rand([])
|
|
842
|
-
if dropout_probability < self.layerdrop: # skip the layer
|
|
843
|
-
-                    to_drop = True
-
-            if to_drop:
-                layer_outputs = (None, None)
-            else:
-                # we add object_queries as extra input to the encoder_layer
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    object_queries=object_queries,
-                    output_attentions=output_attentions,
-                )
-
-                hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
-
-        if output_hidden_states:
-            encoder_states = encoder_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
-        )
-
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+        )

-
-
-
+        for encoder_layer in self.layers:
+            # we add spatial_position_embeddings as extra input to the encoder_layer
+            hidden_states = encoder_layer(
+                hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs
+            )

-
+        return BaseModelOutput(last_hidden_state=hidden_states)

-    Some small tweaks for DETR:

-
-
+class DetrDecoder(DetrPreTrainedModel):
+    """
+    Transformer decoder that refines a set of object queries. It is composed of a stack of [`DetrDecoderLayer`] modules,
+    which apply self-attention to the queries and cross-attention to the encoder's outputs.

     Args:
-        config:
+        config (`DetrConfig`): Model configuration object.
     """

+    _can_record_outputs = {
+        "hidden_states": DetrDecoderLayer,
+        "attentions": DetrSelfAttention,
+        "cross_attentions": DetrCrossAttention,
+    }
+
     def __init__(self, config: DetrConfig):
         super().__init__(config)
         self.dropout = config.dropout
-        self.layerdrop = config.decoder_layerdrop

         self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
         # in DETR, the decoder uses layernorm after the last decoder layer output
         self.layernorm = nn.LayerNorm(config.d_model)

-        self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()

+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
-
-
-
-
-        return_dict=None,
-        **kwargs,
-    ):
+        spatial_position_embeddings=None,
+        object_queries_position_embeddings=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> DetrDecoderOutput:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
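Note on the encoder hunk above: the per-layer LayerDrop and tuple bookkeeping are gone; the forward now builds one bidirectional mask and runs a plain loop over `self.layers`. As a rough mental model only (the real `create_bidirectional_mask` helper is backend-aware and is not reproduced here), expanding a 2D padding mask into the 4D additive form that attention layers consume looks something like this:

```python
# Illustrative sketch only: a plain-PyTorch stand-in for the kind of mask shape
# `create_bidirectional_mask` returns; the actual transformers helper also handles
# backend-specific formats and is not reimplemented here.
import torch


def expand_padding_mask(padding_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Turn a (batch, seq_len) 1/0 padding mask into a (batch, 1, seq_len, seq_len) additive mask."""
    batch_size, seq_len = padding_mask.shape
    # broadcast key validity over the query dimension: padded keys are masked for every query
    keep = padding_mask.bool()[:, None, None, :].expand(batch_size, 1, seq_len, seq_len)
    additive = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=dtype)
    additive.masked_fill_(~keep, torch.finfo(dtype).min)  # dtype-min where attention is disallowed
    return additive


mask = torch.tensor([[1, 1, 1, 0]])        # one padded position
print(expand_padding_mask(mask).shape)     # torch.Size([1, 1, 4, 4])
```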
@@ -931,108 +1048,62 @@ class DetrDecoder(DetrPreTrainedModel):
                 - 1 for pixels that are real (i.e. **not masked**),
                 - 0 for pixels that are padding (i.e. **masked**).

-
-
-
-
-
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Spatial position embeddings (2D positional encodings from encoder) that are added to the keys in each cross-attention layer.
+            object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+                Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if inputs_embeds is not None:
             hidden_states = inputs_embeds
-            input_shape = inputs_embeds.size()[:-1]
-
-        combined_attention_mask = None

-
-
-
-
+        # expand decoder attention mask (for self-attention on object queries)
+        if attention_mask is not None:
+            # [batch_size, num_queries] -> [batch_size, 1, num_queries, num_queries]
+            attention_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=attention_mask,
             )

-        # expand encoder attention mask
+        # expand encoder attention mask (for cross-attention on encoder outputs)
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-            encoder_attention_mask =
-
+            encoder_attention_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=encoder_attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
             )

         # optional intermediate hidden states
         intermediate = () if self.config.auxiliary_loss else None

         # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

         for idx, decoder_layer in enumerate(self.layers):
-
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            if self.training:
-                dropout_probability = torch.rand([])
-                if dropout_probability < self.layerdrop:
-                    continue
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
-
-
-
+                attention_mask,
+                spatial_position_embeddings,
+                object_queries_position_embeddings,
                 encoder_hidden_states,  # as a positional argument for gradient checkpointing
                 encoder_attention_mask=encoder_attention_mask,
-
+                **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
             if self.config.auxiliary_loss:
                 hidden_states = self.layernorm(hidden_states)
                 intermediate += (hidden_states,)

-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-                if encoder_hidden_states is not None:
-                    all_cross_attentions += (layer_outputs[2],)
-
         # finally, apply layernorm
         hidden_states = self.layernorm(hidden_states)

-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         # stack intermediate decoder activations
         if self.config.auxiliary_loss:
             intermediate = torch.stack(intermediate)

-
-            return tuple(
-                v
-                for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate]
-                if v is not None
-            )
-        return DetrDecoderOutput(
-            last_hidden_state=hidden_states,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-            cross_attentions=all_cross_attentions,
-            intermediate_hidden_states=intermediate,
-        )
+        return DetrDecoderOutput(last_hidden_state=hidden_states, intermediate_hidden_states=intermediate)


 @auto_docstring(
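The decoder hunk above replaces the hand-collected `all_hidden_states` / `all_self_attns` / `all_cross_attentions` tuples with the `_can_record_outputs` / `@check_model_inputs()` recording mechanism; only the auxiliary-loss intermediates are still stacked manually. A minimal stand-alone sketch of that remaining bookkeeping (with `nn.Linear` as a hypothetical stand-in for `DetrDecoderLayer`):

```python
# Sketch only: how the auxiliary-loss intermediates are accumulated and stacked.
# nn.Linear stands in for DetrDecoderLayer; shapes mirror DETR defaults.
import torch
from torch import nn

d_model, num_layers, num_queries = 256, 6, 100
layernorm = nn.LayerNorm(d_model)
layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])

hidden_states = torch.randn(2, num_queries, d_model)
intermediate = ()
for layer in layers:
    hidden_states = layer(hidden_states)
    # each per-layer state is layer-normed and kept so auxiliary heads can use every layer
    intermediate += (layernorm(hidden_states),)

intermediate_hidden_states = torch.stack(intermediate)
print(intermediate_hidden_states.shape)  # torch.Size([6, 2, 100, 256])
```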
@@ -1045,15 +1116,16 @@ class DetrModel(DetrPreTrainedModel):
     def __init__(self, config: DetrConfig):
         super().__init__(config)

-
-        backbone = DetrConvEncoder(config)
-        object_queries = build_position_encoding(config)
-        self.backbone = DetrConvModel(backbone, object_queries)
-
-        # Create projection layer
-        self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
+        self.backbone = DetrConvEncoder(config)

+        if config.position_embedding_type == "sine":
+            self.position_embedding = DetrSinePositionEmbedding(config.d_model // 2, normalize=True)
+        elif config.position_embedding_type == "learned":
+            self.position_embedding = DetrLearnedPositionEmbedding(config.d_model // 2)
+        else:
+            raise ValueError(f"Not supported {config.position_embedding_type}")
         self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
+        self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)

         self.encoder = DetrEncoder(config)
         self.decoder = DetrDecoder(config)
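`DetrModel` now builds its own position embedding from the existing `position_embedding_type` config field instead of wrapping the backbone in `DetrConvModel`. A small usage sketch (configuration only, no weights are created or downloaded):

```python
# Usage sketch for the config switch exercised by the __init__ above.
from transformers import DetrConfig

# "sine" (the default) or "learned"; any other value now raises a ValueError in DetrModel.__init__
config = DetrConfig(position_embedding_type="learned")
print(config.position_embedding_type)  # "learned"
```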
@@ -1062,46 +1134,49 @@ class DetrModel(DetrPreTrainedModel):
         self.post_init()

     def freeze_backbone(self):
-        for
+        for _, param in self.backbone.model.named_parameters():
             param.requires_grad_(False)

     def unfreeze_backbone(self):
-        for
+        for _, param in self.backbone.model.named_parameters():
             param.requires_grad_(True)

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
-        pixel_values: torch.FloatTensor,
+        pixel_values: torch.FloatTensor | None = None,
         pixel_mask: torch.LongTensor | None = None,
         decoder_attention_mask: torch.FloatTensor | None = None,
         encoder_outputs: torch.FloatTensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         decoder_inputs_embeds: torch.FloatTensor | None = None,
-
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | DetrModelOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
-
+            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for queries that are **not masked**,
+            - 0 for queries that are **masked**.
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
-            can choose to directly pass a flattened representation of an image.
+            can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
         decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
             Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
+            embedded representation. Useful for tasks that require custom query initialization.

         Examples:

         ```python
         >>> from transformers import AutoImageProcessor, DetrModel
         >>> from PIL import Image
-        >>> import
+        >>> import httpx
+        >>> from io import BytesIO

         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>>
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))

         >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
         >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")
@@ -1118,79 +1193,77 @@ class DetrModel(DetrPreTrainedModel):
         >>> list(last_hidden_states.shape)
         [1, 100, 256]
         ```"""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if pixel_values is None and inputs_embeds is None:
+            raise ValueError("You have to specify either pixel_values or inputs_embeds")
+
+        if inputs_embeds is None:
+            batch_size, num_channels, height, width = pixel_values.shape
+            device = pixel_values.device
+
+            if pixel_mask is None:
+                pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+            vision_features = self.backbone(pixel_values, pixel_mask)
+            feature_map, mask = vision_features[-1]
+
+            # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
+            # Position embeddings are already flattened to (batch_size, sequence_length, hidden_size) format
+            projected_feature_map = self.input_projection(feature_map)
+            flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
+            spatial_position_embeddings = self.position_embedding(
+                shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
+            )
+            flattened_mask = mask.flatten(1)
+        else:
+            batch_size = inputs_embeds.shape[0]
+            device = inputs_embeds.device
+            flattened_features = inputs_embeds
+            # When using inputs_embeds, we need to infer spatial dimensions for position embeddings
+            # Assume square feature map
+            seq_len = inputs_embeds.shape[1]
+            feat_dim = int(seq_len**0.5)
+            # Create position embeddings for the inferred spatial size
+            spatial_position_embeddings = self.position_embedding(
+                shape=torch.Size([batch_size, self.config.d_model, feat_dim, feat_dim]),
+                device=device,
+                dtype=inputs_embeds.dtype,
+            )
+            # If a pixel_mask is provided with inputs_embeds, interpolate it to feat_dim, then flatten.
+            if pixel_mask is not None:
+                mask = nn.functional.interpolate(pixel_mask[None].float(), size=(feat_dim, feat_dim)).to(torch.bool)[0]
+                flattened_mask = mask.flatten(1)
+            else:
+                # If no mask provided, assume all positions are valid
+                flattened_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long)

-        # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
-        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
-        # flattened_mask is a Tensor of shape (batch_size, height*width)
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 inputs_embeds=flattened_features,
                 attention_mask=flattened_mask,
-
-
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+                spatial_position_embeddings=spatial_position_embeddings,
+                **kwargs,
             )

-
-
-
+        object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(
+            batch_size, 1, 1
+        )
+
+        # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
+        if decoder_inputs_embeds is not None:
+            queries = decoder_inputs_embeds
+        else:
+            queries = torch.zeros_like(object_queries_position_embeddings)

         # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
         decoder_outputs = self.decoder(
             inputs_embeds=queries,
-            attention_mask=
-
-
-            encoder_hidden_states=encoder_outputs
+            attention_mask=decoder_attention_mask,
+            spatial_position_embeddings=spatial_position_embeddings,
+            object_queries_position_embeddings=object_queries_position_embeddings,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
             encoder_attention_mask=flattened_mask,
-
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

-        if not return_dict:
-            return decoder_outputs + encoder_outputs
-
         return DetrModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
             decoder_hidden_states=decoder_outputs.hidden_states,
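A quick worked example of the square-grid fallback in the `inputs_embeds` branch above: `feat_dim = int(seq_len**0.5)` only reconstructs the true spatial layout when the flattened feature map was square.

```python
# Worked arithmetic for the square-feature-map assumption used above.
seq_len = 34 * 25                     # 850 tokens from a non-square 34x25 feature map
feat_dim = int(seq_len ** 0.5)        # 29
print(feat_dim, feat_dim * feat_dim)  # 29 841 -> position embeddings only line up exactly for square maps
```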
@@ -1203,14 +1276,11 @@ class DetrModel(DetrPreTrainedModel):
         )


-# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
 class DetrMLPPredictionHead(nn.Module):
     """
     Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
     height and width of a bounding box w.r.t. an image.

-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-
     """

     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
@@ -1250,6 +1320,7 @@ class DetrForObjectDetection(DetrPreTrainedModel):
         self.post_init()

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
@@ -1259,20 +1330,20 @@ class DetrForObjectDetection(DetrPreTrainedModel):
         inputs_embeds: torch.FloatTensor | None = None,
         decoder_inputs_embeds: torch.FloatTensor | None = None,
         labels: list[dict] | None = None,
-
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | DetrObjectDetectionOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
-
+            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for queries that are **not masked**,
+            - 0 for queries that are **masked**.
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
-            can choose to directly pass a flattened representation of an image.
+            can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
         decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
             Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
+            embedded representation. Useful for tasks that require custom query initialization.
         labels (`list[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
             following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
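For the `decoder_attention_mask` documented above, a hedged usage sketch (shapes and 1/0 semantics are taken from the docstring; the actual model call is left commented out so the snippet stays lightweight):

```python
# Sketch: disable a subset of the object-query slots via decoder_attention_mask.
import torch

batch_size, num_queries = 2, 100
decoder_attention_mask = torch.ones(batch_size, num_queries, dtype=torch.long)
decoder_attention_mask[:, 50:] = 0  # 1 = query slot attended to, 0 = query slot masked out

# would be passed alongside the pixel inputs, e.g.:
# outputs = model(pixel_values=pixel_values, decoder_attention_mask=decoder_attention_mask)
```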
@@ -1285,10 +1356,12 @@ class DetrForObjectDetection(DetrPreTrainedModel):
         >>> from transformers import AutoImageProcessor, DetrForObjectDetection
         >>> import torch
         >>> from PIL import Image
-        >>> import
+        >>> import httpx
+        >>> from io import BytesIO

         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>>
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))

         >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
         >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
@@ -1314,7 +1387,6 @@ class DetrForObjectDetection(DetrPreTrainedModel):
         Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
         Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
         ```"""
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # First, sent images through DETR base model to obtain encoder + decoder outputs
         outputs = self.model(
@@ -1324,9 +1396,7 @@ class DetrForObjectDetection(DetrPreTrainedModel):
             encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
-
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

         sequence_output = outputs[0]
@@ -1339,20 +1409,13 @@ class DetrForObjectDetection(DetrPreTrainedModel):
         if labels is not None:
             outputs_class, outputs_coord = None, None
             if self.config.auxiliary_loss:
-                intermediate = outputs.intermediate_hidden_states
+                intermediate = outputs.intermediate_hidden_states
                 outputs_class = self.class_labels_classifier(intermediate)
                 outputs_coord = self.bbox_predictor(intermediate).sigmoid()
             loss, loss_dict, auxiliary_outputs = self.loss_function(
                 logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
             )

-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes) + auxiliary_outputs + outputs
-            else:
-                output = (logits, pred_boxes) + outputs
-            return ((loss, loss_dict) + output) if loss is not None else output
-
         return DetrObjectDetectionOutput(
             loss=loss,
             loss_dict=loss_dict,
@@ -1376,6 +1439,26 @@ class DetrForObjectDetection(DetrPreTrainedModel):
     """
 )
 class DetrForSegmentation(DetrPreTrainedModel):
+    _checkpoint_conversion_mapping = {
+        "bbox_attention.q_linear": "bbox_attention.q_proj",
+        "bbox_attention.k_linear": "bbox_attention.k_proj",
+        # Mask head refactor
+        "mask_head.lay1": "mask_head.conv1.conv",
+        "mask_head.gn1": "mask_head.conv1.norm",
+        "mask_head.lay2": "mask_head.conv2.conv",
+        "mask_head.gn2": "mask_head.conv2.norm",
+        "mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter",
+        "mask_head.lay3": "mask_head.fpn_stages.0.refine.conv",
+        "mask_head.gn3": "mask_head.fpn_stages.0.refine.norm",
+        "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter",
+        "mask_head.lay4": "mask_head.fpn_stages.1.refine.conv",
+        "mask_head.gn4": "mask_head.fpn_stages.1.refine.norm",
+        "mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter",
+        "mask_head.lay5": "mask_head.fpn_stages.2.refine.conv",
+        "mask_head.gn5": "mask_head.fpn_stages.2.refine.norm",
+        "mask_head.out_lay": "mask_head.output_conv",
+    }
+
     def __init__(self, config: DetrConfig):
         super().__init__(config)

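Illustration only: the `_checkpoint_conversion_mapping` added above is what lets pre-refactor checkpoints load into the renamed `bbox_attention` and `mask_head` submodules. The actual renaming happens inside transformers' weight-loading machinery (which may use pattern matching rather than plain prefixes); the sketch below merely shows the kind of key rewrite such a table implies:

```python
# Hypothetical helper, not library code: apply a prefix-style renaming table to a state dict.
def rename_keys(state_dict: dict, mapping: dict) -> dict:
    renamed = {}
    for key, value in state_dict.items():
        for old_prefix, new_prefix in mapping.items():
            if key.startswith(old_prefix):
                key = new_prefix + key[len(old_prefix):]
                break
        renamed[key] = value
    return renamed


mapping = {"mask_head.lay1": "mask_head.conv1.conv", "mask_head.gn1": "mask_head.conv1.norm"}
old = {"mask_head.lay1.weight": 0, "mask_head.gn1.bias": 1, "other.weight": 2}
print(rename_keys(old, mapping))
# {'mask_head.conv1.conv.weight': 0, 'mask_head.conv1.norm.bias': 1, 'other.weight': 2}
```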
@@ -1384,19 +1467,21 @@ class DetrForSegmentation(DetrPreTrainedModel):

         # segmentation head
         hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
-        intermediate_channel_sizes = self.detr.model.backbone.
+        intermediate_channel_sizes = self.detr.model.backbone.intermediate_channel_sizes

         self.mask_head = DetrMaskHeadSmallConv(
-            hidden_size + number_of_heads,
+            input_channels=hidden_size + number_of_heads,
+            fpn_channels=intermediate_channel_sizes[::-1][-3:],
+            hidden_size=hidden_size,
+            activation_function=config.activation_function,
         )

-        self.bbox_attention = DetrMHAttentionMap(
-            hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
-        )
+        self.bbox_attention = DetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0)
         # Initialize weights and apply final processing
         self.post_init()

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
@@ -1406,20 +1491,20 @@ class DetrForSegmentation(DetrPreTrainedModel):
         inputs_embeds: torch.FloatTensor | None = None,
         decoder_inputs_embeds: torch.FloatTensor | None = None,
         labels: list[dict] | None = None,
-
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | DetrSegmentationOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
-
+            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for queries that are **not masked**,
+            - 0 for queries that are **masked**.
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-
-
+            Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires
+            multi-scale features from the backbone that are not available when bypassing it with inputs_embeds.
         decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
             Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
+            embedded representation. Useful for tasks that require custom query initialization.
         labels (`list[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
             dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
@@ -1432,7 +1517,8 @@ class DetrForSegmentation(DetrPreTrainedModel):

         ```python
         >>> import io
-        >>> import
+        >>> import httpx
+        >>> from io import BytesIO
         >>> from PIL import Image
         >>> import torch
         >>> import numpy
@@ -1441,7 +1527,8 @@ class DetrForSegmentation(DetrPreTrainedModel):
         >>> from transformers.image_transforms import rgb_to_id

         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>>
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))

         >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
         >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
@@ -1466,83 +1553,77 @@ class DetrForSegmentation(DetrPreTrainedModel):
         5
         ```"""

-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         batch_size, num_channels, height, width = pixel_values.shape
         device = pixel_values.device

         if pixel_mask is None:
             pixel_mask = torch.ones((batch_size, height, width), device=device)

-
-
+        vision_features = self.detr.model.backbone(pixel_values, pixel_mask)
+        feature_map, mask = vision_features[-1]

-        #
-        feature_map, mask = features[-1]
-        batch_size, num_channels, height, width = feature_map.shape
+        # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
         projected_feature_map = self.detr.model.input_projection(feature_map)
-
-        # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
-        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
         flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
-
-
+        spatial_position_embeddings = self.detr.model.position_embedding(
+            shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
+        )
         flattened_mask = mask.flatten(1)

-        # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
-        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
-        # flattened_mask is a Tensor of shape (batch_size, height*width)
         if encoder_outputs is None:
             encoder_outputs = self.detr.model.encoder(
                 inputs_embeds=flattened_features,
                 attention_mask=flattened_mask,
-
-
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+                spatial_position_embeddings=spatial_position_embeddings,
+                **kwargs,
             )

-
-        query_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
+        object_queries_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
             batch_size, 1, 1
         )
-        queries = torch.zeros_like(query_position_embeddings)

-        #
+        # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
+        if decoder_inputs_embeds is not None:
+            queries = decoder_inputs_embeds
+        else:
+            queries = torch.zeros_like(object_queries_position_embeddings)
+
         decoder_outputs = self.detr.model.decoder(
             inputs_embeds=queries,
-            attention_mask=
-
-
-            encoder_hidden_states=encoder_outputs
+            attention_mask=decoder_attention_mask,
+            spatial_position_embeddings=spatial_position_embeddings,
+            object_queries_position_embeddings=object_queries_position_embeddings,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
             encoder_attention_mask=flattened_mask,
-
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

         sequence_output = decoder_outputs[0]

-        # Sixth, compute logits, pred_boxes and pred_masks
         logits = self.detr.class_labels_classifier(sequence_output)
         pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()

-
-
+        height, width = feature_map.shape[-2:]
+        memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view(
+            batch_size, self.config.d_model, height, width
+        )
+        attention_mask = flattened_mask.view(batch_size, height, width)

-
-
-
-
+        if attention_mask is not None:
+            min_dtype = torch.finfo(memory.dtype).min
+            attention_mask = torch.where(
+                attention_mask.unsqueeze(1).unsqueeze(1),
+                torch.tensor(0.0, device=memory.device, dtype=memory.dtype),
+                min_dtype,
+            )

-
+        bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=attention_mask)
+
+        seg_masks = self.mask_head(
+            features=projected_feature_map,
+            attention_masks=bbox_mask,
+            fpn_features=[vision_features[2][0], vision_features[1][0], vision_features[0][0]],
+        )

         pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

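The masking step right before `bbox_attention` in the hunk above, isolated as a plain-PyTorch sketch: a boolean validity mask becomes an additive mask that is 0 where attention is allowed and dtype-min elsewhere.

```python
# Stand-alone sketch mirroring the torch.where conversion shown in the diff.
import torch

memory_dtype = torch.float32
mask = torch.tensor([[[True, True], [True, False]]])  # (batch=1, H=2, W=2) validity mask
min_dtype = torch.finfo(memory_dtype).min
additive = torch.where(
    mask.unsqueeze(1).unsqueeze(1),                    # broadcast to (batch, 1, 1, H, W)
    torch.tensor(0.0, dtype=memory_dtype),
    min_dtype,
)
print(additive.shape)    # torch.Size([1, 1, 1, 2, 2])
print(additive[0, 0, 0])  # 0.0 where valid, ~-3.4e38 at the padded position
```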
@@ -1550,20 +1631,13 @@ class DetrForSegmentation(DetrPreTrainedModel):
         if labels is not None:
             outputs_class, outputs_coord = None, None
             if self.config.auxiliary_loss:
-                intermediate = decoder_outputs.intermediate_hidden_states
+                intermediate = decoder_outputs.intermediate_hidden_states
                 outputs_class = self.detr.class_labels_classifier(intermediate)
                 outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
             loss, loss_dict, auxiliary_outputs = self.loss_function(
                 logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
             )

-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
-            else:
-                output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
-            return ((loss, loss_dict) + output) if loss is not None else output
-
         return DetrSegmentationOutput(
             loss=loss,
             loss_dict=loss_dict,
@@ -1581,119 +1655,6 @@ class DetrForSegmentation(DetrPreTrainedModel):
         )


-def _expand(tensor, length: int):
-    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
-
-
-# taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py
-class DetrMaskHeadSmallConv(nn.Module):
-    """
-    Simple convolutional head, using group norm. Upsampling is done using a FPN approach
-    """
-
-    def __init__(self, dim, fpn_dims, context_dim):
-        super().__init__()
-
-        if dim % 8 != 0:
-            raise ValueError(
-                "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
-                " GroupNorm is set to 8"
-            )
-
-        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
-
-        self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
-        self.gn1 = nn.GroupNorm(8, dim)
-        self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
-        self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
-        self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
-        self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
-        self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
-        self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
-        self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
-        self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
-        self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
-
-        self.dim = dim
-
-        self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
-        self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
-        self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
-
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                init.kaiming_uniform_(m.weight, a=1)
-                init.constant_(m.bias, 0)
-
-    def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]):
-        # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
-        # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
-        # We expand the projected feature map to match the number of heads.
-        x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
-
-        x = self.lay1(x)
-        x = self.gn1(x)
-        x = nn.functional.relu(x)
-        x = self.lay2(x)
-        x = self.gn2(x)
-        x = nn.functional.relu(x)
-
-        cur_fpn = self.adapter1(fpns[0])
-        if cur_fpn.size(0) != x.size(0):
-            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
-        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
-        x = self.lay3(x)
-        x = self.gn3(x)
-        x = nn.functional.relu(x)
-
-        cur_fpn = self.adapter2(fpns[1])
-        if cur_fpn.size(0) != x.size(0):
-            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
-        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
-        x = self.lay4(x)
-        x = self.gn4(x)
-        x = nn.functional.relu(x)
-
-        cur_fpn = self.adapter3(fpns[2])
-        if cur_fpn.size(0) != x.size(0):
-            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
-        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
-        x = self.lay5(x)
-        x = self.gn5(x)
-        x = nn.functional.relu(x)
-
-        x = self.out_lay(x)
-        return x
-
-
-class DetrMHAttentionMap(nn.Module):
-    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
-
-    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
-        super().__init__()
-        self.num_heads = num_heads
-        self.hidden_dim = hidden_dim
-        self.dropout = nn.Dropout(dropout)
-
-        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
-        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
-
-        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
-
-    def forward(self, q, k, mask: Tensor | None = None):
-        q = self.q_linear(q)
-        k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
-        queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
-        keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
-        weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
-
-        if mask is not None:
-            weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
-        weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
-        weights = self.dropout(weights)
-        return weights
-
-
 __all__ = [
     "DetrForObjectDetection",
     "DetrForSegmentation",