transformers-5.0.0rc3-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
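For context, a file-level comparison like the listing below can be reproduced locally from the two wheels. The sketch is illustrative only and assumes both wheel files have already been downloaded into the working directory (for example with `pip download transformers==5.0.0rc3 --no-deps`); the filenames are the canonical PyPI names, not part of this page, and the per-file +/- line counts shown below would additionally require extracting and diffing each member, which is omitted here.

```python
# Minimal sketch: compare the file manifests of two wheels (zip archives)
# using only the Python standard library. Paths are illustrative.
import zipfile


def wheel_files(path: str) -> set[str]:
    """Return the set of file paths contained in a wheel archive."""
    with zipfile.ZipFile(path) as wheel:
        return {name for name in wheel.namelist() if not name.endswith("/")}


old = wheel_files("transformers-5.0.0rc3-py3-none-any.whl")
new = wheel_files("transformers-5.1.0-py3-none-any.whl")

print("added files:", sorted(new - old))
print("removed files:", sorted(old - new))
print("files present in both (possibly modified):", len(old & new))
```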
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,9 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deformable_detr/modular_deformable_detr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deformable_detr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # Copyright 2022 SenseTime and The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,128 +17,54 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch Deformable DETR model."""
-
 import math
 import warnings
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any

 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
+from torch import Tensor

 from ... import initialization as init
 from ...activations import ACT2FN
+from ...backbone_utils import load_backbone
 from ...integrations import use_kernel_forward_from_hub
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutput
-from ...modeling_utils import PreTrainedModel
-from ...
-from ...
-
-
-    is_timm_available,
-    logging,
-    requires_backends,
-)
-from ...utils.backbone_utils import load_backbone
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache, meshgrid
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check
+from ...utils.generic import OutputRecorder, can_return_tuple, check_model_inputs
 from .configuration_deformable_detr import DeformableDetrConfig


-logger = logging.get_logger(__name__)
-
-
-if is_timm_available():
-    from timm import create_model
-
-
-logger = logging.get_logger(__name__)
-
-
-@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
-class MultiScaleDeformableAttention(nn.Module):
-    def forward(
-        self,
-        value: Tensor,
-        value_spatial_shapes: Tensor,
-        value_spatial_shapes_list: list[tuple],
-        level_start_index: Tensor,
-        sampling_locations: Tensor,
-        attention_weights: Tensor,
-        im2col_step: int,
-    ):
-        batch_size, _, num_heads, hidden_dim = value.shape
-        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-        value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
-        sampling_grids = 2 * sampling_locations - 1
-        sampling_value_list = []
-        for level_id, (height, width) in enumerate(value_spatial_shapes_list):
-            # batch_size, height*width, num_heads, hidden_dim
-            # -> batch_size, height*width, num_heads*hidden_dim
-            # -> batch_size, num_heads*hidden_dim, height*width
-            # -> batch_size*num_heads, hidden_dim, height, width
-            value_l_ = (
-                value_list[level_id]
-                .flatten(2)
-                .transpose(1, 2)
-                .reshape(batch_size * num_heads, hidden_dim, height, width)
-            )
-            # batch_size, num_queries, num_heads, num_points, 2
-            # -> batch_size, num_heads, num_queries, num_points, 2
-            # -> batch_size*num_heads, num_queries, num_points, 2
-            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
-            # batch_size*num_heads, hidden_dim, num_queries, num_points
-            sampling_value_l_ = nn.functional.grid_sample(
-                value_l_,
-                sampling_grid_l_,
-                mode="bilinear",
-                padding_mode="zeros",
-                align_corners=False,
-            )
-            sampling_value_list.append(sampling_value_l_)
-        # (batch_size, num_queries, num_heads, num_levels, num_points)
-        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
-        # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
-        attention_weights = attention_weights.transpose(1, 2).reshape(
-            batch_size * num_heads, 1, num_queries, num_levels * num_points
-        )
-        output = (
-            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
-            .sum(-1)
-            .view(batch_size, num_heads * hidden_dim, num_queries)
-        )
-        return output.transpose(1, 2).contiguous()
-
-
 @dataclass
 @auto_docstring(
     custom_intro="""
-    Base class for outputs of the
-
-
-    - a stacked tensor of intermediate reference points.
+    Base class for outputs of the DEFORMABLE_DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
+    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
+    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
     """
 )
-class DeformableDetrDecoderOutput(
+class DeformableDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
     r"""
-    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
-        Stacked intermediate hidden states (output of each layer of the decoder).
-    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
-        Stacked intermediate reference points (reference points of each layer of the decoder).
     cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
         Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
         sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
         used to compute the weighted average in the cross-attention heads.
+    intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
+        Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
+        layernorm.
+    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
+        Stacked intermediate reference points (reference points of each layer of the decoder).
     """

-    last_hidden_state: torch.FloatTensor | None = None
     intermediate_hidden_states: torch.FloatTensor | None = None
+
     intermediate_reference_points: torch.FloatTensor | None = None
-    hidden_states: tuple[torch.FloatTensor] | None = None
-    attentions: tuple[torch.FloatTensor] | None = None
-    cross_attentions: tuple[torch.FloatTensor] | None = None


 @dataclass
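
For reference on the hunk above: the decoder output dataclass now inherits `last_hidden_state`, `hidden_states`, `attentions` and `cross_attentions` from `BaseModelOutputWithCrossAttentions` instead of declaring them itself. A minimal sketch of that pattern, using a hypothetical `MyDecoderOutput` name rather than the class from the diff:

    from dataclasses import dataclass

    import torch

    from transformers.modeling_outputs import BaseModelOutputWithCrossAttentions


    @dataclass
    class MyDecoderOutput(BaseModelOutputWithCrossAttentions):
        # last_hidden_state / hidden_states / attentions / cross_attentions come from the
        # base class; only the model-specific extras are declared here.
        intermediate_hidden_states: torch.FloatTensor | None = None
        intermediate_reference_points: torch.FloatTensor | None = None
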
@@ -198,10 +130,10 @@ class DeformableDetrObjectDetectionOutput(ModelOutput):
         Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
         and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
         `pred_boxes`) for each decoder layer.
-    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
-        Initial reference points sent through the Transformer decoder.
     last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
         Sequence of hidden-states at the output of the last layer of the decoder of the model.
+    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+        Initial reference points sent through the Transformer decoder.
     intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
         Stacked intermediate hidden states (output of each layer of the decoder).
     intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
@@ -219,28 +151,76 @@ class DeformableDetrObjectDetectionOutput(ModelOutput):
     logits: torch.FloatTensor | None = None
     pred_boxes: torch.FloatTensor | None = None
     auxiliary_outputs: list[dict] | None = None
-    init_reference_points: torch.FloatTensor | None = None
     last_hidden_state: torch.FloatTensor | None = None
-    intermediate_hidden_states: torch.FloatTensor | None = None
-    intermediate_reference_points: torch.FloatTensor | None = None
     decoder_hidden_states: tuple[torch.FloatTensor] | None = None
     decoder_attentions: tuple[torch.FloatTensor] | None = None
     cross_attentions: tuple[torch.FloatTensor] | None = None
     encoder_last_hidden_state: torch.FloatTensor | None = None
     encoder_hidden_states: tuple[torch.FloatTensor] | None = None
     encoder_attentions: tuple[torch.FloatTensor] | None = None
-
+
+    init_reference_points: torch.FloatTensor | None = None
+    intermediate_hidden_states: torch.FloatTensor | None = None
+    intermediate_reference_points: torch.FloatTensor | None = None
+    enc_outputs_class: torch.FloatTensor | None = None
     enc_outputs_coord_logits: torch.FloatTensor | None = None


-
-
-
-
-
+@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+class MultiScaleDeformableAttention(nn.Module):
+    def forward(
+        self,
+        value: Tensor,
+        value_spatial_shapes: Tensor,
+        value_spatial_shapes_list: list[tuple],
+        level_start_index: Tensor,
+        sampling_locations: Tensor,
+        attention_weights: Tensor,
+        im2col_step: int,
+    ):
+        batch_size, _, num_heads, hidden_dim = value.shape
+        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+        value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
+        sampling_grids = 2 * sampling_locations - 1
+        sampling_value_list = []
+        for level_id, (height, width) in enumerate(value_spatial_shapes_list):
+            # batch_size, height*width, num_heads, hidden_dim
+            # -> batch_size, height*width, num_heads*hidden_dim
+            # -> batch_size, num_heads*hidden_dim, height*width
+            # -> batch_size*num_heads, hidden_dim, height, width
+            value_l_ = (
+                value_list[level_id]
+                .flatten(2)
+                .transpose(1, 2)
+                .reshape(batch_size * num_heads, hidden_dim, height, width)
+            )
+            # batch_size, num_queries, num_heads, num_points, 2
+            # -> batch_size, num_heads, num_queries, num_points, 2
+            # -> batch_size*num_heads, num_queries, num_points, 2
+            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+            # batch_size*num_heads, hidden_dim, num_queries, num_points
+            sampling_value_l_ = nn.functional.grid_sample(
+                value_l_,
+                sampling_grid_l_,
+                mode="bilinear",
+                padding_mode="zeros",
+                align_corners=False,
+            )
+            sampling_value_list.append(sampling_value_l_)
+        # (batch_size, num_queries, num_heads, num_levels, num_points)
+        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+        # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+        attention_weights = attention_weights.transpose(1, 2).reshape(
+            batch_size * num_heads, 1, num_queries, num_levels * num_points
+        )
+        output = (
+            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+            .sum(-1)
+            .view(batch_size, num_heads * hidden_dim, num_queries)
+        )
+        return output.transpose(1, 2).contiguous()


-# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->DeformableDetr
 class DeformableDetrFrozenBatchNorm2d(nn.Module):
     """
     BatchNorm2d where the batch statistics and the affine parameters are fixed.
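
The relocated `MultiScaleDeformableAttention.forward` above is the pure-PyTorch path: it splits the flattened multi-level features, samples each level with `grid_sample` at the predicted locations, and mixes the samples with the attention weights. A minimal shape-check sketch, assuming toy dimensions and the pure-PyTorch path (no hub kernel picked up); the sizes and `im2col_step=64` are illustrative, not taken from the diff:

    import torch

    from transformers.models.deformable_detr.modeling_deformable_detr import MultiScaleDeformableAttention

    # Toy dimensions, chosen only for illustration.
    batch_size, num_heads, hidden_dim = 2, 8, 32
    num_queries, num_levels, num_points = 100, 2, 4
    spatial_shapes_list = [(32, 32), (16, 16)]
    seq_len = sum(h * w for h, w in spatial_shapes_list)  # 1024 + 256 = 1280

    value = torch.randn(batch_size, seq_len, num_heads, hidden_dim)
    sampling_locations = torch.rand(batch_size, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.softmax(
        torch.randn(batch_size, num_queries, num_heads, num_levels * num_points), dim=-1
    ).view(batch_size, num_queries, num_heads, num_levels, num_points)

    spatial_shapes = torch.tensor(spatial_shapes_list)
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

    attn = MultiScaleDeformableAttention()
    out = attn(value, spatial_shapes, spatial_shapes_list, level_start_index,
               sampling_locations, attention_weights, im2col_step=64)
    print(out.shape)  # torch.Size([2, 100, 256]) == (batch, queries, num_heads * hidden_dim)
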
@@ -280,7 +260,6 @@ class DeformableDetrFrozenBatchNorm2d(nn.Module):
         return x * scale + bias


-# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->DeformableDetr
 def replace_batch_norm(model):
     r"""
     Recursively replace all `torch.nn.BatchNorm2d` with `DeformableDetrFrozenBatchNorm2d`.
@@ -318,57 +297,36 @@ class DeformableDetrConvEncoder(nn.Module):

         self.config = config

-
-
-            # We default to values which were previously hard-coded. This enables configurability from the config
-            # using backbone arguments, while keeping the default behavior the same.
-            requires_backends(self, ["timm"])
-            kwargs = getattr(config, "backbone_kwargs", {})
-            kwargs = {} if kwargs is None else kwargs.copy()
-            out_indices = kwargs.pop("out_indices", (2, 3, 4) if config.num_feature_levels > 1 else (4,))
-            num_channels = kwargs.pop("in_chans", config.num_channels)
-            if config.dilation:
-                kwargs["output_stride"] = kwargs.get("output_stride", 16)
-            backbone = create_model(
-                config.backbone,
-                pretrained=config.use_pretrained_backbone,
-                features_only=True,
-                out_indices=out_indices,
-                in_chans=num_channels,
-                **kwargs,
-            )
-        else:
-            backbone = load_backbone(config)
+        backbone = load_backbone(config)
+        self.intermediate_channel_sizes = backbone.channels

         # replace batch norm by frozen batch norm
         with torch.no_grad():
             replace_batch_norm(backbone)
-        self.model = backbone
-        self.intermediate_channel_sizes = (
-            self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
-        )

-
-
-
-
-
-
-
+        # We used to load with timm library directly instead of the AutoBackbone API
+        # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
+        is_timm_model = False
+        if hasattr(backbone, "_backbone"):
+            backbone = backbone._backbone
+            is_timm_model = True
+        self.model = backbone

+        backbone_model_type = config.backbone_config.model_type
         if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
-                if
+                if is_timm_model:
                     if "layer2" not in name and "layer3" not in name and "layer4" not in name:
                         parameter.requires_grad_(False)
                 else:
                     if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
                         parameter.requires_grad_(False)

-    # Copied from transformers.models.detr.modeling_detr.DetrConvEncoder.forward with Detr->DeformableDetr
     def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
         # send pixel_values through the model to get list of feature maps
-        features = self.model(pixel_values)
+        features = self.model(pixel_values)
+        if isinstance(features, dict):
+            features = features.feature_maps

         out = []
         for feature_map in features:
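
With the hunk above, both the timm and the AutoBackbone routes go through `load_backbone(config)`, and a timm wrapper is unwrapped via its `_backbone` attribute so that checkpoint keys and the stage-freezing logic keep matching. A hedged sketch of how the two routes are now driven purely by the config; the argument names are the ones `DeformableDetrConfig` has accepted in recent releases and the exact values are illustrative, so treat this as an assumption rather than a guaranteed recipe for this version:

    from transformers import DeformableDetrConfig, ResNetConfig

    # timm route: load_backbone() wraps a timm "resnet50"; the wrapper keeps the raw model
    # under ._backbone, which the hunk above unwraps before freezing stages.
    timm_cfg = DeformableDetrConfig(
        backbone="resnet50",
        use_timm_backbone=True,
        use_pretrained_backbone=False,
        backbone_kwargs={"out_indices": (2, 3, 4)},  # feature levels exposed to the encoder
    )

    # Transformers route: load_backbone() builds an AutoBackbone from backbone_config instead.
    hf_cfg = DeformableDetrConfig(
        use_timm_backbone=False,
        backbone_config=ResNetConfig(out_features=["stage2", "stage3", "stage4"]),
    )
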
@@ -378,67 +336,58 @@ class DeformableDetrConvEncoder(nn.Module):
         return out


-# Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->DeformableDetr
-class DeformableDetrConvModel(nn.Module):
-    """
-    This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
-    """
-
-    def __init__(self, conv_encoder, position_embedding):
-        super().__init__()
-        self.conv_encoder = conv_encoder
-        self.position_embedding = position_embedding
-
-    def forward(self, pixel_values, pixel_mask):
-        # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
-        out = self.conv_encoder(pixel_values, pixel_mask)
-        pos = []
-        for feature_map, mask in out:
-            # position encoding
-            pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
-
-        return out, pos
-
-
 class DeformableDetrSinePositionEmbedding(nn.Module):
     """
     This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
     need paper, generalized to work on images.
     """

-    def __init__(
+    def __init__(
+        self,
+        num_position_features: int = 64,
+        temperature: int = 10000,
+        normalize: bool = False,
+        scale: float | None = None,
+    ):
         super().__init__()
-        self.embedding_dim = embedding_dim
-        self.temperature = temperature
-        self.normalize = normalize
         if scale is not None and normalize is False:
             raise ValueError("normalize should be True if scale is passed")
-
-
-        self.
+        self.num_position_features = num_position_features
+        self.temperature = temperature
+        self.normalize = normalize
+        self.scale = 2 * math.pi if scale is None else scale

-
-
-
-
-
+    @compile_compatible_method_lru_cache(maxsize=1)
+    def forward(
+        self,
+        shape: torch.Size,
+        device: torch.device | str,
+        dtype: torch.dtype,
+        mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if mask is None:
+            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+        y_embed = mask.cumsum(1, dtype=dtype)
+        x_embed = mask.cumsum(2, dtype=dtype)
         if self.normalize:
             eps = 1e-6
             y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

-        dim_t = torch.arange(self.
-        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.
+        dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
+        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)

         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+        # expected by the encoder
+        pos = pos.flatten(2).permute(0, 2, 1)
         return pos


-# Copied from transformers.models.detr.modeling_detr.DetrLearnedPositionEmbedding
 class DeformableDetrLearnedPositionEmbedding(nn.Module):
     """
     This module learns positional embeddings up to a fixed maximum size.
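
The rewritten sine embedding above now takes `(shape, device, dtype, mask)`, is cached with `compile_compatible_method_lru_cache`, and returns positions already flattened to `(batch_size, height * width, 2 * num_position_features)`. As a worked illustration of the underlying formula only, here is a standalone sketch for a fully valid (unmasked) image; it deliberately ignores the module's `mask`/`normalize` handling:

    import torch


    def sine_position_embedding_2d(height, width, num_position_features=64, temperature=10000):
        # Row/column indices play the role of the cumulative-sum coordinates in the module.
        y_embed = torch.arange(1, height + 1, dtype=torch.float32).view(1, height, 1).expand(1, height, width)
        x_embed = torch.arange(1, width + 1, dtype=torch.float32).view(1, 1, width).expand(1, height, width)

        dim_t = torch.arange(num_position_features, dtype=torch.float32)
        dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_position_features)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        # Same final reshape as the new forward: (batch, H*W, 2 * num_position_features).
        return pos.flatten(2).permute(0, 2, 1)


    print(sine_position_embedding_2d(16, 16).shape)  # torch.Size([1, 256, 128])
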
@@ -449,31 +398,122 @@ class DeformableDetrLearnedPositionEmbedding(nn.Module):
         self.row_embeddings = nn.Embedding(50, embedding_dim)
         self.column_embeddings = nn.Embedding(50, embedding_dim)

-
-
-
-
+    @compile_compatible_method_lru_cache(maxsize=1)
+    def forward(
+        self,
+        shape: torch.Size,
+        device: torch.device | str,
+        dtype: torch.dtype,
+        mask: torch.Tensor | None = None,
+    ):
+        height, width = shape[-2:]
+        width_values = torch.arange(width, device=device)
+        height_values = torch.arange(height, device=device)
         x_emb = self.column_embeddings(width_values)
         y_emb = self.row_embeddings(height_values)
         pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
         pos = pos.permute(2, 0, 1)
         pos = pos.unsqueeze(0)
-        pos = pos.repeat(
+        pos = pos.repeat(shape[0], 1, 1, 1)
+        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+        # expected by the encoder
+        pos = pos.flatten(2).permute(0, 2, 1)
         return pos


-
-
-
-
-
-
-
-
-
-
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor | None,
+    scaling: float | None = None,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+class DeformableDetrSelfAttention(nn.Module):
+    """
+    Multi-headed self-attention from 'Attention Is All You Need' paper.
+
+    In DEFORMABLE_DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
+    """
+
+    def __init__(
+        self,
+        config: DeformableDetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.config = config
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+        self.is_causal = False
+
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Position embeddings are added to both queries and keys (but not values).
+        """
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
+
+        query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )

-
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights


 class DeformableDetrMultiscaleDeformableAttention(nn.Module):
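
The new `DeformableDetrSelfAttention` above projects to `(batch, heads, seq, head_dim)` and dispatches through `ALL_ATTENTION_FUNCTIONS.get_interface(config._attn_implementation, eager_attention_forward)`, so SDPA/flash/flex kernels can replace the eager path without touching the layer. A small self-contained sketch of what the eager fallback computes, checked against PyTorch's fused kernel (toy shapes, no masking or dropout):

    import torch
    import torch.nn as nn

    # Toy shapes matching the layout used above: (batch, num_heads, seq_len, head_dim).
    batch, heads, seq, head_dim = 2, 8, 10, 32
    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)
    v = torch.randn(batch, heads, seq, head_dim)

    # Eager path: scaled dot product, softmax over keys, weighted sum of values.
    scaling = head_dim ** -0.5
    weights = torch.softmax(torch.matmul(q, k.transpose(2, 3)) * scaling, dim=-1)
    out = torch.matmul(weights, v).transpose(1, 2)  # (batch, seq, heads, head_dim)

    # The "sdpa" backend dispatches to PyTorch's fused kernel, which computes the same thing.
    ref = nn.functional.scaled_dot_product_attention(q, k, v).transpose(1, 2)
    print(torch.allclose(out, ref, atol=1e-5))  # expected: True, up to numerical tolerance
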
@@ -513,9 +553,6 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):

         self.disable_custom_kernels = config.disable_custom_kernels

-    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Tensor | None):
-        return tensor if position_embeddings is None else tensor + position_embeddings
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -527,19 +564,19 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
         spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
-
-    ):
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # add position embeddings to the hidden states before projecting to queries and keys
         if position_embeddings is not None:
-            hidden_states =
+            hidden_states = hidden_states + position_embeddings

         batch_size, num_queries, _ = hidden_states.shape
         batch_size, sequence_length, _ = encoder_hidden_states.shape
         total_elements = sum(height * width for height, width in spatial_shapes_list)
-
-
-
-
+        torch_compilable_check(
+            total_elements == sequence_length,
+            "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
+        )

         value = self.value_proj(encoder_hidden_states)
         if attention_mask is not None:
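
`torch_compilable_check` above replaces a plain Python assert so the consistency test also works under `torch.compile`; it verifies that the flattened encoder sequence length equals the summed per-level areas. A toy calculation of that invariant (numbers are illustrative):

    # Toy feature pyramid: a 640x640 image at strides 8, 16, 32 and 64.
    spatial_shapes_list = [(80, 80), (40, 40), (20, 20), (10, 10)]
    total_elements = sum(h * w for h, w in spatial_shapes_list)
    print(total_elements)  # 6400 + 1600 + 400 + 100 = 8500
    # encoder_hidden_states must therefore have sequence_length == 8500 for the check to pass.
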
@@ -586,159 +623,48 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
         return output, attention_weights


-class
-
-    Multi-headed attention from 'Attention Is All You Need' paper.
-
-    Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
-    """
-
-    def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        dropout: float = 0.0,
-        bias: bool = True,
-    ):
+class DeformableDetrMLP(nn.Module):
+    def __init__(self, config: DeformableDetrConfig, hidden_size: int, intermediate_size: int):
         super().__init__()
-        self.
-        self.
-        self.
-        self.
-
-            raise ValueError(
-                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {num_heads})."
-            )
-        self.scaling = self.head_dim**-0.5
-
-        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
-        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Tensor | None):
-        return tensor if position_embeddings is None else tensor + position_embeddings
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor | None = None,
-        position_embeddings: torch.Tensor | None = None,
-        output_attentions: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
-        """Input shape: Batch x Time x Channel"""
-
-        batch_size, target_len, embed_dim = hidden_states.size()
-        # add position embeddings to the hidden states before projecting to queries and keys
-        if position_embeddings is not None:
-            hidden_states_original = hidden_states
-            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
-
-        # get queries, keys and values
-        query_states = self.q_proj(hidden_states) * self.scaling
-        key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
-        value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
-
-        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
-
-        source_len = key_states.size(1)
-
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
-        if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
-            raise ValueError(
-                f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        # expand attention_mask
-        if attention_mask is not None:
-            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
-
-        if attention_mask is not None:
-            if attention_mask.size() != (batch_size, 1, target_len, source_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
-                    f" {attention_mask.size()}"
-                )
-            if attention_mask.dtype == torch.bool:
-                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
-                    attention_mask, -torch.inf
-                )
-            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
-            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
-            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
-        else:
-            attn_weights_reshaped = None
-
-        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
-        attn_output = torch.bmm(attn_probs, value_states)
-
-        if attn_output.size() != (
-            batch_size * self.num_heads,
-            target_len,
-            self.head_dim,
-        ):
-            raise ValueError(
-                f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
-
-        attn_output = self.out_proj(attn_output)
+        self.fc1 = nn.Linear(hidden_size, intermediate_size)
+        self.fc2 = nn.Linear(intermediate_size, hidden_size)
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.dropout = config.dropout

-
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        return hidden_states


 class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DeformableDetrConfig):
         super().__init__()
-        self.
+        self.hidden_size = config.d_model
         self.self_attn = DeformableDetrMultiscaleDeformableAttention(
             config,
             num_heads=config.encoder_attention_heads,
             n_points=config.encoder_n_points,
         )
-        self.self_attn_layer_norm = nn.LayerNorm(self.
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
         self.dropout = config.dropout
-        self.
-        self.
-        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
-        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.mlp = DeformableDetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-
+        spatial_position_embeddings: torch.Tensor | None = None,
         reference_points=None,
         spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
-
-    ):
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -753,24 +679,18 @@ class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
                 Spatial shapes of the backbone feature maps.
             level_start_index (`torch.LongTensor`, *optional*):
                 Level start index.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
         """
         residual = hidden_states
-
-        # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps.
-        hidden_states, attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=attention_mask,
-            position_embeddings=
+            position_embeddings=spatial_position_embeddings,
             reference_points=reference_points,
             spatial_shapes=spatial_shapes,
             spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
-            output_attentions=output_attentions,
         )

         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -778,12 +698,7 @@ class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
         hidden_states = self.self_attn_layer_norm(hidden_states)

         residual = hidden_states
-        hidden_states = self.
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
+        hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)

@@ -792,54 +707,44 @@ class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

-
-
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs
+        return hidden_states


 class DeformableDetrDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DeformableDetrConfig):
         super().__init__()
-        self.
+        self.hidden_size = config.d_model

-
-
-
-
+        self.self_attn = DeformableDetrSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.decoder_attention_heads,
             dropout=config.attention_dropout,
         )
         self.dropout = config.dropout
-        self.activation_fn = ACT2FN[config.activation_function]
-        self.activation_dropout = config.activation_dropout

-        self.self_attn_layer_norm = nn.LayerNorm(self.
-        # cross-attention
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
         self.encoder_attn = DeformableDetrMultiscaleDeformableAttention(
             config,
             num_heads=config.decoder_attention_heads,
             n_points=config.decoder_n_points,
         )
-        self.encoder_attn_layer_norm = nn.LayerNorm(self.
-
-        self.
-        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+        self.mlp = DeformableDetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

     def forward(
         self,
         hidden_states: torch.Tensor,
-
+        object_queries_position_embeddings: torch.Tensor | None = None,
         reference_points=None,
         spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
         encoder_hidden_states: torch.Tensor | None = None,
         encoder_attention_mask: torch.Tensor | None = None,
-
-    ):
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`):
@@ -857,60 +762,47 @@ class DeformableDetrDecoderLayer(GradientCheckpointingLayer):
             encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
         """
         residual = hidden_states

         # Self Attention
-        hidden_states,
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
-            position_embeddings=
-
+            position_embeddings=object_queries_position_embeddings,
+            **kwargs,
         )

         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)

-
+        residual = hidden_states

         # Cross-Attention
-
-        hidden_states, cross_attn_weights = self.encoder_attn(
+        hidden_states, _ = self.encoder_attn(
             hidden_states=hidden_states,
             attention_mask=encoder_attention_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
-            position_embeddings=
+            position_embeddings=object_queries_position_embeddings,
             reference_points=reference_points,
             spatial_shapes=spatial_shapes,
             spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
-            output_attentions=output_attentions,
         )

         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-        hidden_states =
+        hidden_states = residual + hidden_states

         hidden_states = self.encoder_attn_layer_norm(hidden_states)

         # Fully Connected
         residual = hidden_states
-        hidden_states = self.
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)

-
-
-        if output_attentions:
-            outputs += (self_attn_weights, cross_attn_weights)
-
-        return outputs
+        return hidden_states


 @auto_docstring
@@ -925,6 +817,13 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
         r"DeformableDetrEncoderLayer",
         r"DeformableDetrDecoderLayer",
     ]
+    _supports_sdpa = True
+    _supports_flash_attn = True
+    _supports_attention_backend = True
+    _supports_flex_attn = True
+    _keys_to_ignore_on_load_unexpected = [
+        r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
+    ]

     @torch.no_grad()
     def _init_weights(self, module):
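
The `_supports_sdpa` / `_supports_flash_attn` / `_supports_flex_attn` / `_supports_attention_backend` flags added above advertise that the rewritten self-attention can be served by alternative kernels selected at load time. A hedged usage sketch; the checkpoint id is only an example of a Deformable DETR checkpoint on the Hub:

    from transformers import DeformableDetrForObjectDetection

    # "sdpa" uses PyTorch's fused kernel; "eager" falls back to the explicit matmul/softmax path.
    model = DeformableDetrForObjectDetection.from_pretrained(
        "SenseTime/deformable-detr",   # example checkpoint id
        attn_implementation="sdpa",
    )
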
@@ -982,9 +881,13 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
        config: DeformableDetrConfig
    """

+    _can_record_outputs = {
+        "hidden_states": DeformableDetrEncoderLayer,
+        "attentions": OutputRecorder(DeformableDetrMultiscaleDeformableAttention, layer_name="self_attn", index=1),
+    }
+
     def __init__(self, config: DeformableDetrConfig):
         super().__init__(config)
-        self.gradient_checkpointing = False

         self.dropout = config.dropout
         self.layers = nn.ModuleList([DeformableDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
@@ -992,51 +895,18 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    @
-    def get_reference_points(spatial_shapes, valid_ratios, device):
-        """
-        Get reference points for each feature map. Used in decoder.
-
-        Args:
-            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
-                Spatial shapes of each feature map.
-            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
-                Valid ratios of each feature map.
-            device (`torch.device`):
-                Device on which to create the tensors.
-        Returns:
-            `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
-        """
-        reference_points_list = []
-        for level, (height, width) in enumerate(spatial_shapes):
-            ref_y, ref_x = meshgrid(
-                torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
-                torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
-                indexing="ij",
-            )
-            # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
-            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
-            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
-            ref = torch.stack((ref_x, ref_y), -1)
-            reference_points_list.append(ref)
-        reference_points = torch.cat(reference_points_list, 1)
-        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
-        return reference_points
-
+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         attention_mask=None,
-
+        spatial_position_embeddings=None,
         spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
-
-
-        return_dict=None,
-        **kwargs,
-    ):
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
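
With `@check_model_inputs()` and `_can_record_outputs` (previous hunks), the encoder no longer threads `output_attentions` / `output_hidden_states` / `return_dict` through every layer; the flags travel via `**kwargs` and matching layer outputs are recorded automatically. A hedged sketch of how a caller still requests them; the checkpoint id and the printed fields are illustrative:

    import torch

    from transformers import DeformableDetrForObjectDetection

    model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr")  # example checkpoint
    pixel_values = torch.randn(1, 3, 800, 800)

    with torch.no_grad():
        outputs = model(pixel_values, output_attentions=True, output_hidden_states=True)

    # The recorded tensors land on the usual output fields instead of being returned by each layer.
    print(len(outputs.encoder_hidden_states), outputs.encoder_attentions[0].shape)
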
@@ -1046,66 +916,72 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
                 - 1 for pixel features that are real (i.e. **not masked**),
                 - 0 for pixel features that are padding (i.e. **masked**).
                 [What are attention masks?](../glossary#attention-mask)
-
-
+            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
             spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
                 Spatial shapes of each feature map.
             level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
                 Starting index of each feature map.
             valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
                 Ratio of valid area in each feature level.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         hidden_states = inputs_embeds
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

         spatial_shapes_tuple = tuple(spatial_shapes_list)
         reference_points = self.get_reference_points(spatial_shapes_tuple, valid_ratios, device=inputs_embeds.device)

-
-
-        for i, encoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                encoder_states = encoder_states + (hidden_states,)
-            layer_outputs = encoder_layer(
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(
                 hidden_states,
                 attention_mask,
-
+                spatial_position_embeddings=spatial_position_embeddings,
                 reference_points=reference_points,
                 spatial_shapes=spatial_shapes,
                 spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
-
+                **kwargs,
             )

-
+        return BaseModelOutput(last_hidden_state=hidden_states)

-
-
+    @staticmethod
+    def get_reference_points(spatial_shapes_list, valid_ratios, device):
+        """
+        Get reference points for each feature map. Used in decoder.
+
+        Args:
+            spatial_shapes_list (`list[tuple[int, int]]`):
+                Spatial shapes of each feature map.
+            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
+                Valid ratios of each feature map.
+            device (`torch.device`):
+                Device on which to create the tensors.
+        Returns:
+            `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
+        """
+        reference_points_list = []
+        for level, (height, width) in enumerate(spatial_shapes_list):
+            ref_y, ref_x = meshgrid(
+                torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
+                torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
+                indexing="ij",
+            )
+            # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
+            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
+            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
+            ref = torch.stack((ref_x, ref_y), -1)
+            reference_points_list.append(ref)
+        reference_points = torch.cat(reference_points_list, 1)
+        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+        return reference_points

-        if output_hidden_states:
-            encoder_states = encoder_states + (hidden_states,)

-
-
-
-
-
-            attentions=all_attentions,
-        )
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0, max=1)
+    x1 = x.clamp(min=eps)
+    x2 = (1 - x).clamp(min=eps)
+    return torch.log(x1 / x2)


 class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
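Both `get_reference_points` (now a `@staticmethod` keyed on `spatial_shapes_list`) and the new module-level `inverse_sigmoid` are small pieces of tensor math that can be exercised in isolation. A standalone sketch that mirrors them, with `torch.meshgrid` standing in for the module-local `meshgrid` wrapper:

```python
import torch


def inverse_sigmoid(x, eps=1e-5):
    # Mirrors the helper added above: a clamped logit, finite even at 0 and 1.
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


def get_reference_points(spatial_shapes_list, valid_ratios, device):
    # Mirrors the staticmethod added above.
    reference_points_list = []
    for level, (height, width) in enumerate(spatial_shapes_list):
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
            torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
            indexing="ij",
        )
        ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
        ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
        reference_points_list.append(torch.stack((ref_x, ref_y), -1))
    reference_points = torch.cat(reference_points_list, 1)
    return reference_points[:, :, None] * valid_ratios[:, None]


# Two feature levels (4x4 and 2x2), batch of 1, fully valid feature maps.
valid_ratios = torch.ones(1, 2, 2)
points = get_reference_points([(4, 4), (2, 2)], valid_ratios, device="cpu")
print(points.shape)  # torch.Size([1, 20, 2, 2]) -> (batch, 16 + 4 points, levels, xy)

x = torch.tensor([0.0, 0.25, 0.5, 0.99, 1.0])
print(torch.allclose(inverse_sigmoid(x).sigmoid(), x.clamp(1e-5, 1 - 1e-5), atol=1e-4))  # True
```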
@@ -1123,12 +999,19 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
     config: DeformableDetrConfig
     """

+    _can_record_outputs = {
+        "hidden_states": DeformableDetrDecoderLayer,
+        "attentions": OutputRecorder(DeformableDetrSelfAttention, layer_name="self_attn", index=1),
+        "cross_attentions": OutputRecorder(
+            DeformableDetrMultiscaleDeformableAttention, layer_name="encoder_attn", index=1
+        ),
+    }
+
     def __init__(self, config: DeformableDetrConfig):
         super().__init__(config)

         self.dropout = config.dropout
         self.layers = nn.ModuleList([DeformableDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.gradient_checkpointing = False

         # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
         self.bbox_embed = None
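With `_can_record_outputs`, the decoder no longer threads `output_attentions` / `output_hidden_states` through every layer and accumulates tuples by hand; intermediate outputs of the declared module types are recorded for it. A deliberately simplified illustration of that idea using plain forward hooks (not the actual `OutputRecorder` implementation):

```python
import torch
from torch import nn


def record_outputs(model, module_type, inputs):
    """Toy illustration of type-keyed output recording via forward hooks."""
    recorded = []
    hooks = [
        m.register_forward_hook(lambda mod, args, out: recorded.append(out))
        for m in model.modules()
        if isinstance(m, module_type)
    ]
    try:
        last = model(inputs)
    finally:
        for h in hooks:
            h.remove()
    return last, tuple(recorded)


# A stand-in "decoder" made of three identical layers.
decoder = nn.Sequential(*[nn.Linear(8, 8) for _ in range(3)])
last, hidden_states = record_outputs(decoder, nn.Linear, torch.randn(2, 8))
print(len(hidden_states))  # 3: one recorded tensor per layer, no manual tuple threading
```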
@@ -1137,21 +1020,19 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
-
+        object_queries_position_embeddings=None,
         reference_points=None,
         spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
-
-        output_hidden_states=None,
-        return_dict=None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ):
         r"""
         Args:
@@ -1165,8 +1046,8 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
                 in `[0, 1]`:
                 - 1 for pixels that are real (i.e. **not masked**),
                 - 0 for pixels that are padding (i.e. **masked**).
-
-                Position embeddings that are added to the queries and keys in each self-attention layer.
+            object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+                Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer.
             reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*):
                 Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
             spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
@@ -1176,28 +1057,11 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
             valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
                 Ratio of valid area in each feature level.

-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         if inputs_embeds is not None:
             hidden_states = inputs_embeds

         # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
         intermediate = ()
         intermediate_reference_points = ()

@@ -1212,23 +1076,18 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
             else:
                 raise ValueError("Reference points' last dimension must be of size 2")

-
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
-
+                object_queries_position_embeddings,
                 reference_points_input,
                 spatial_shapes,
                 spatial_shapes_list,
                 level_start_index,
                 encoder_hidden_states,  # as a positional argument for gradient checkpointing
                 encoder_attention_mask,
-
+                **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
             # hack implementation for iterative bounding box refinement
             if self.bbox_embed is not None:
                 tmp = self.bbox_embed[idx](hidden_states)
@@ -1249,40 +1108,14 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
             intermediate += (hidden_states,)
             intermediate_reference_points += (reference_points,)

-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-                if encoder_hidden_states is not None:
-                    all_cross_attentions += (layer_outputs[2],)
-
         # Keep batch_size as first dimension
         intermediate = torch.stack(intermediate, dim=1)
         intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)

-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        if not return_dict:
-            return tuple(
-                v
-                for v in [
-                    hidden_states,
-                    intermediate,
-                    intermediate_reference_points,
-                    all_hidden_states,
-                    all_self_attns,
-                    all_cross_attentions,
-                ]
-                if v is not None
-            )
         return DeformableDetrDecoderOutput(
             last_hidden_state=hidden_states,
             intermediate_hidden_states=intermediate,
             intermediate_reference_points=intermediate_reference_points,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-            cross_attentions=all_cross_attentions,
         )

@@ -1296,17 +1129,23 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
     def __init__(self, config: DeformableDetrConfig):
         super().__init__(config)

-        # Create backbone
-        backbone = DeformableDetrConvEncoder(config)
-
-
+        # Create backbone
+        self.backbone = DeformableDetrConvEncoder(config)
+
+        # Create positional encoding
+        if config.position_embedding_type == "sine":
+            self.position_embedding = DeformableDetrSinePositionEmbedding(config.d_model // 2, normalize=True)
+        elif config.position_embedding_type == "learned":
+            self.position_embedding = DeformableDetrLearnedPositionEmbedding(config.d_model // 2)
+        else:
+            raise ValueError(f"Not supported {config.position_embedding_type}")

         # Create input projection layers
         if config.num_feature_levels > 1:
-            num_backbone_outs = len(backbone.intermediate_channel_sizes)
+            num_backbone_outs = len(self.backbone.intermediate_channel_sizes)
             input_proj_list = []
             for _ in range(num_backbone_outs):
-                in_channels = backbone.intermediate_channel_sizes[_]
+                in_channels = self.backbone.intermediate_channel_sizes[_]
                 input_proj_list.append(
                     nn.Sequential(
                         nn.Conv2d(in_channels, config.d_model, kernel_size=1),
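`DeformableDetrModel` now owns its backbone and positional encoding directly, and the embedding class follows `config.position_embedding_type`. A small sketch of that dispatch, assuming a default `DeformableDetrConfig` (d_model=256, so the embedding dimension passed on is 128):

```python
from transformers import DeformableDetrConfig


def pick_position_embedding(config):
    # Mirrors the branch in __init__ above, returning a label instead of a module.
    if config.position_embedding_type == "sine":
        return f"sine(dim={config.d_model // 2}, normalize=True)"
    elif config.position_embedding_type == "learned":
        return f"learned(dim={config.d_model // 2})"
    raise ValueError(f"Not supported {config.position_embedding_type}")


config = DeformableDetrConfig()
print(config.position_embedding_type)   # "sine" by default
print(pick_position_embedding(config))  # sine(dim=128, normalize=True)
```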
@@ -1333,7 +1172,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
                 [
                     nn.Sequential(
                         nn.Conv2d(
-                            backbone.intermediate_channel_sizes[-1],
+                            self.backbone.intermediate_channel_sizes[-1],
                             config.d_model,
                             kernel_size=1,
                         ),
@@ -1361,11 +1200,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         self.post_init()

     def freeze_backbone(self):
-        for name, param in self.backbone.
+        for name, param in self.backbone.model.named_parameters():
             param.requires_grad_(False)

     def unfreeze_backbone(self):
-        for name, param in self.backbone.
+        for name, param in self.backbone.model.named_parameters():
             param.requires_grad_(True)

     def get_valid_ratio(self, mask, dtype=torch.float32):
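`freeze_backbone` / `unfreeze_backbone` now iterate over `self.backbone.model` directly. A hedged usage sketch, assuming `model` is an already-constructed `DeformableDetrModel` whose backbone parameters start out trainable:

```python
# Assumes `model` is an already-constructed DeformableDetrModel instance.
def num_trainable(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


before = num_trainable(model)
model.freeze_backbone()
print(num_trainable(model.backbone.model))  # 0: every backbone parameter now has requires_grad=False
print(num_trainable(model) < before)        # True: only the backbone was touched
model.unfreeze_backbone()
print(num_trainable(model) == before)       # True here, since the backbone started fully trainable
```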
@@ -1386,15 +1225,18 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         temperature = 10000
         scale = 2 * math.pi

-
+        # Compute position embeddings in float32 to avoid overflow with large temperature values in fp16
+        proposals_dtype = proposals.dtype
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
         dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
         # batch_size, num_queries, 4
-        proposals = proposals.sigmoid() * scale
+        proposals = proposals.sigmoid().to(torch.float32) * scale
         # batch_size, num_queries, 4, 128
         pos = proposals[:, :, :, None] / dim_t
         # batch_size, num_queries, 4, 64, 2 -> batch_size, num_queries, 512
         pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
-
+        # Convert back to target dtype after all computations are done
+        return pos.to(proposals_dtype)

     def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes):
         """Generate the encoder output proposals from encoded enc_output.
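`get_proposal_pos_embed` now does its trigonometric math in float32 and casts back at the end, per the new in-line comment about overflow with large temperature values in fp16. A sketch that mirrors the new code path on a half-precision input; shapes follow the in-line comments above (300 proposals, 4 box coordinates, 128 features per coordinate):

```python
import math

import torch


def pos_embed_fp32(proposals, num_pos_feats=128, temperature=10000):
    # Same compute-in-float32, cast-back pattern as get_proposal_pos_embed above.
    scale = 2 * math.pi
    proposals_dtype = proposals.dtype
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
    proposals = proposals.sigmoid().to(torch.float32) * scale
    pos = proposals[:, :, :, None] / dim_t
    pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
    return pos.to(proposals_dtype)


out = pos_embed_fp32(torch.randn(1, 300, 4, dtype=torch.float16))
print(out.dtype, out.shape)  # torch.float16 torch.Size([1, 300, 512]): fp32 internally, caller dtype outside
```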
@@ -1458,6 +1300,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         return object_query, output_proposals

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
@@ -1466,10 +1309,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         encoder_outputs: torch.FloatTensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         decoder_inputs_embeds: torch.FloatTensor | None = None,
-
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | DeformableDetrModelOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1502,12 +1342,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         >>> list(last_hidden_states.shape)
         [1, 300, 256]
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         batch_size, num_channels, height, width = pixel_values.shape
         device = pixel_values.device

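With the flag plumbing removed from `DeformableDetrModel.forward`, options such as `output_attentions`, `output_hidden_states` and `return_dict` travel through `**kwargs` and are resolved by `@check_model_inputs` / `@can_return_tuple`. A hedged usage sketch, assuming `model` is a `DeformableDetrModel` and `inputs` came from the matching image processor:

```python
# Assumes `model` is a DeformableDetrModel and `inputs` came from the matching image processor.
outputs = model(**inputs, output_attentions=True, output_hidden_states=True)
print(type(outputs).__name__)           # DeformableDetrModelOutput
print(outputs.last_hidden_state.shape)  # (batch_size, num_queries, d_model), e.g. [1, 300, 256]

legacy = model(**inputs, return_dict=False)
print(isinstance(legacy, tuple))        # True: @can_return_tuple keeps the tuple escape hatch
```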
@@ -1517,16 +1351,22 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         # Extract multi-scale feature maps of same resolution `config.d_model` (cf Figure 4 in paper)
         # First, sent pixel_values + pixel_mask through Backbone to obtain the features
         # which is a list of tuples
-        features
+        features = self.backbone(pixel_values, pixel_mask)

         # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
         sources = []
         masks = []
+        position_embeddings_list = []
         for level, (source, mask) in enumerate(features):
             sources.append(self.input_proj[level](source))
             masks.append(mask)
             if mask is None:
                 raise ValueError("No attention mask was provided")
+            # Generate position embeddings for this feature level
+            pos = self.position_embedding(shape=source.shape, device=device, dtype=pixel_values.dtype, mask=mask).to(
+                source.dtype
+            )
+            position_embeddings_list.append(pos)

         # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
         if self.config.num_feature_levels > len(sources):
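The backbone now returns only feature/mask pairs, and the model computes the per-level position embeddings itself via `self.position_embedding(shape=..., device=..., dtype=..., mask=...)`. A sketch of that call on a dummy feature level, again assuming a constructed `model`; the feature sizes are arbitrary:

```python
import torch

# Assumes `model` is a DeformableDetrModel; mirrors the per-level call in forward above.
batch_size, channels, height, width = 1, 2048, 25, 34  # arbitrary backbone output size
source = torch.zeros(batch_size, channels, height, width)
mask = torch.ones(batch_size, height, width, dtype=torch.bool)  # all pixels valid

pos = model.position_embedding(
    shape=source.shape, device=source.device, dtype=source.dtype, mask=mask
).to(source.dtype)
print(pos.shape)  # one d_model-channel positional map for this feature level
```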
@@ -1539,7 +1379,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
                 mask = nn.functional.interpolate(pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]).to(
                     torch.bool
                 )[0]
-                pos_l = self.
+                pos_l = self.position_embedding(
+                    shape=source.shape, device=device, dtype=pixel_values.dtype, mask=mask
+                ).to(source.dtype)
                 sources.append(source)
                 masks.append(mask)
                 position_embeddings_list.append(pos_l)
@@ -1560,7 +1402,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             spatial_shapes_list.append(spatial_shape)
             source = source.flatten(2).transpose(1, 2)
             mask = mask.flatten(1)
-            pos_embed = pos_embed.flatten(2).transpose(1, 2)
             lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1)
             lvl_pos_embed_flatten.append(lvl_pos_embed)
             source_flatten.append(source)
@@ -1578,21 +1419,12 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             encoder_outputs = self.encoder(
                 inputs_embeds=source_flatten,
                 attention_mask=mask_flatten,
-
+                spatial_position_embeddings=lvl_pos_embed_flatten,
                 spatial_shapes=spatial_shapes,
                 spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 valid_ratios=valid_ratios,
-
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+                **kwargs,
             )

         # Fifth, prepare decoder inputs
@@ -1635,7 +1467,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):

         decoder_outputs = self.decoder(
             inputs_embeds=target,
-
+            object_queries_position_embeddings=query_embed,
             encoder_hidden_states=encoder_outputs[0],
             encoder_attention_mask=mask_flatten,
             reference_points=reference_points,
@@ -1643,17 +1475,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             valid_ratios=valid_ratios,
-
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

-        if not return_dict:
-            enc_outputs = tuple(value for value in [enc_outputs_class, enc_outputs_coord_logits] if value is not None)
-            tuple_outputs = (init_reference_points,) + decoder_outputs + encoder_outputs + enc_outputs
-
-            return tuple_outputs
-
         return DeformableDetrModelOutput(
             init_reference_points=init_reference_points,
             last_hidden_state=decoder_outputs.last_hidden_state,
@@ -1670,14 +1494,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         )


-# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead
 class DeformableDetrMLPPredictionHead(nn.Module):
     """
     Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
     height and width of a bounding box w.r.t. an image.

-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-
     """

     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
@@ -1726,15 +1547,18 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
                 for _ in range(num_pred)
             ]
         )
+        # Convert to instance attribute before modifying
+        self._tied_weights_keys = self._tied_weights_keys.copy()
         if config.with_box_refine:
             self.model.decoder.bbox_embed = self.bbox_embed
-            self._tied_weights_keys["
+            self._tied_weights_keys["bbox_embed"] = "model.decoder.bbox_embed"
         if config.two_stage:
             self.model.decoder.class_embed = self.class_embed
-            self._tied_weights_keys["
+            self._tied_weights_keys["class_embed"] = "model.decoder.class_embed"
         self.post_init()

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
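`_tied_weights_keys` is copied onto the instance before the `bbox_embed` / `class_embed` entries are added (see the new comment above). The likely motivation is the usual class-attribute pitfall, shown here in miniature with a hypothetical `Head` class:

```python
class Head:
    shared_keys = {"base": "model.base"}

    def __init__(self, refine, copy_first):
        if copy_first:
            self.shared_keys = self.shared_keys.copy()  # instance-level copy, as in the diff
        if refine:
            self.shared_keys["bbox_embed"] = "model.decoder.bbox_embed"


a = Head(refine=True, copy_first=False)
print("bbox_embed" in Head.shared_keys)  # True: the class-level dict was mutated for every instance

Head.shared_keys = {"base": "model.base"}  # reset for the second run
b = Head(refine=True, copy_first=True)
print("bbox_embed" in Head.shared_keys)  # False: only this instance carries the extra tie
```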
@@ -1744,10 +1568,7 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
         inputs_embeds: torch.FloatTensor | None = None,
         decoder_inputs_embeds: torch.FloatTensor | None = None,
         labels: list[dict] | None = None,
-
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | DeformableDetrObjectDetectionOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1795,8 +1616,6 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
         Detected cat with confidence 0.789 at location [342.19, 24.3, 640.02, 372.25]
         Detected remote with confidence 0.633 at location [40.79, 72.78, 176.76, 117.25]
         ```"""
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # First, sent images through DETR base model to obtain encoder + decoder outputs
         outputs = self.model(
             pixel_values,
@@ -1805,14 +1624,12 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
             encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
-
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

-        hidden_states = outputs.intermediate_hidden_states
-        init_reference = outputs.init_reference_points
-        inter_references = outputs.intermediate_reference_points
+        hidden_states = outputs.intermediate_hidden_states
+        init_reference = outputs.init_reference_points
+        inter_references = outputs.intermediate_reference_points

         # class logits + predicted bounding boxes
         outputs_classes = []
@@ -1853,16 +1670,8 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
                 outputs_class,
                 outputs_coord,
             )
-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes) + auxiliary_outputs + outputs
-            else:
-                output = (logits, pred_boxes) + outputs
-            tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output
-
-            return tuple_outputs

-
+        return DeformableDetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
@@ -1882,11 +1691,5 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
            enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
        )

-        return dict_outputs
-

-__all__ = [
-    "DeformableDetrForObjectDetection",
-    "DeformableDetrModel",
-    "DeformableDetrPreTrainedModel",
-]
+__all__ = ["DeformableDetrForObjectDetection", "DeformableDetrModel", "DeformableDetrPreTrainedModel"]