transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in those registries.
- transformers/__init__.py +11 -37
- transformers/activations.py +2 -2
- transformers/audio_utils.py +32 -32
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +26 -126
- transformers/cli/chat.py +3 -3
- transformers/cli/serve.py +13 -10
- transformers/cli/transformers.py +2 -1
- transformers/configuration_utils.py +22 -92
- transformers/conversion_mapping.py +150 -26
- transformers/convert_slow_tokenizer.py +9 -12
- transformers/core_model_loading.py +217 -129
- transformers/data/processors/glue.py +0 -1
- transformers/data/processors/utils.py +0 -1
- transformers/data/processors/xnli.py +0 -1
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +10 -11
- transformers/distributed/configuration_utils.py +1 -2
- transformers/dynamic_module_utils.py +23 -23
- transformers/feature_extraction_sequence_utils.py +19 -23
- transformers/feature_extraction_utils.py +14 -14
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +2 -4
- transformers/generation/configuration_utils.py +54 -39
- transformers/generation/continuous_batching/__init__.py +0 -1
- transformers/generation/continuous_batching/cache.py +74 -44
- transformers/generation/continuous_batching/cache_manager.py +28 -28
- transformers/generation/continuous_batching/continuous_api.py +133 -414
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +77 -19
- transformers/generation/continuous_batching/scheduler.py +154 -104
- transformers/generation/logits_process.py +10 -133
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/streamers.py +0 -1
- transformers/generation/utils.py +91 -121
- transformers/generation/watermarking.py +2 -3
- transformers/hf_argparser.py +9 -13
- transformers/hyperparameter_search.py +1 -2
- transformers/image_processing_base.py +9 -9
- transformers/image_processing_utils.py +11 -15
- transformers/image_processing_utils_fast.py +70 -71
- transformers/image_transforms.py +73 -42
- transformers/image_utils.py +30 -37
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/awq.py +1 -3
- transformers/integrations/deepspeed.py +146 -4
- transformers/integrations/eetq.py +0 -1
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/fbgemm_fp8.py +1 -2
- transformers/integrations/finegrained_fp8.py +149 -13
- transformers/integrations/flash_attention.py +3 -8
- transformers/integrations/flex_attention.py +1 -1
- transformers/integrations/fp_quant.py +4 -6
- transformers/integrations/ggml.py +0 -1
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/integration_utils.py +2 -3
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +52 -40
- transformers/integrations/peft.py +488 -176
- transformers/integrations/quark.py +2 -4
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/integrations/torchao.py +4 -6
- transformers/loss/loss_lw_detr.py +356 -0
- transformers/loss/loss_utils.py +2 -0
- transformers/masking_utils.py +199 -59
- transformers/model_debugging_utils.py +4 -5
- transformers/modelcard.py +14 -192
- transformers/modeling_attn_mask_utils.py +19 -19
- transformers/modeling_flash_attention_utils.py +28 -29
- transformers/modeling_gguf_pytorch_utils.py +5 -5
- transformers/modeling_layers.py +21 -22
- transformers/modeling_outputs.py +242 -253
- transformers/modeling_rope_utils.py +32 -32
- transformers/modeling_utils.py +416 -438
- transformers/models/__init__.py +10 -0
- transformers/models/afmoe/configuration_afmoe.py +40 -33
- transformers/models/afmoe/modeling_afmoe.py +38 -41
- transformers/models/afmoe/modular_afmoe.py +23 -25
- transformers/models/aimv2/configuration_aimv2.py +2 -10
- transformers/models/aimv2/modeling_aimv2.py +46 -45
- transformers/models/aimv2/modular_aimv2.py +13 -19
- transformers/models/albert/configuration_albert.py +8 -2
- transformers/models/albert/modeling_albert.py +70 -72
- transformers/models/albert/tokenization_albert.py +1 -4
- transformers/models/align/configuration_align.py +8 -6
- transformers/models/align/modeling_align.py +83 -86
- transformers/models/align/processing_align.py +2 -30
- transformers/models/altclip/configuration_altclip.py +4 -7
- transformers/models/altclip/modeling_altclip.py +106 -103
- transformers/models/altclip/processing_altclip.py +2 -15
- transformers/models/apertus/__init__.py +0 -1
- transformers/models/apertus/configuration_apertus.py +23 -28
- transformers/models/apertus/modeling_apertus.py +35 -38
- transformers/models/apertus/modular_apertus.py +36 -40
- transformers/models/arcee/configuration_arcee.py +25 -30
- transformers/models/arcee/modeling_arcee.py +35 -38
- transformers/models/arcee/modular_arcee.py +20 -23
- transformers/models/aria/configuration_aria.py +31 -44
- transformers/models/aria/image_processing_aria.py +25 -27
- transformers/models/aria/modeling_aria.py +102 -102
- transformers/models/aria/modular_aria.py +111 -124
- transformers/models/aria/processing_aria.py +28 -35
- transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
- transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
- transformers/models/audioflamingo3/__init__.py +0 -1
- transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
- transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
- transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
- transformers/models/auto/auto_factory.py +12 -11
- transformers/models/auto/configuration_auto.py +48 -5
- transformers/models/auto/feature_extraction_auto.py +5 -7
- transformers/models/auto/image_processing_auto.py +30 -39
- transformers/models/auto/modeling_auto.py +33 -199
- transformers/models/auto/processing_auto.py +11 -19
- transformers/models/auto/tokenization_auto.py +38 -37
- transformers/models/auto/video_processing_auto.py +7 -8
- transformers/models/autoformer/configuration_autoformer.py +4 -7
- transformers/models/autoformer/modeling_autoformer.py +100 -101
- transformers/models/aya_vision/configuration_aya_vision.py +4 -1
- transformers/models/aya_vision/modeling_aya_vision.py +64 -99
- transformers/models/aya_vision/modular_aya_vision.py +46 -74
- transformers/models/aya_vision/processing_aya_vision.py +25 -53
- transformers/models/bamba/configuration_bamba.py +46 -39
- transformers/models/bamba/modeling_bamba.py +83 -119
- transformers/models/bamba/modular_bamba.py +70 -109
- transformers/models/bark/configuration_bark.py +6 -8
- transformers/models/bark/generation_configuration_bark.py +3 -5
- transformers/models/bark/modeling_bark.py +64 -65
- transformers/models/bark/processing_bark.py +19 -41
- transformers/models/bart/configuration_bart.py +9 -5
- transformers/models/bart/modeling_bart.py +124 -129
- transformers/models/barthez/tokenization_barthez.py +1 -4
- transformers/models/bartpho/tokenization_bartpho.py +6 -7
- transformers/models/beit/configuration_beit.py +2 -15
- transformers/models/beit/image_processing_beit.py +53 -56
- transformers/models/beit/image_processing_beit_fast.py +11 -12
- transformers/models/beit/modeling_beit.py +65 -62
- transformers/models/bert/configuration_bert.py +12 -2
- transformers/models/bert/modeling_bert.py +117 -152
- transformers/models/bert/tokenization_bert.py +2 -4
- transformers/models/bert/tokenization_bert_legacy.py +3 -5
- transformers/models/bert_generation/configuration_bert_generation.py +17 -2
- transformers/models/bert_generation/modeling_bert_generation.py +53 -55
- transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
- transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
- transformers/models/bertweet/tokenization_bertweet.py +1 -3
- transformers/models/big_bird/configuration_big_bird.py +12 -9
- transformers/models/big_bird/modeling_big_bird.py +107 -124
- transformers/models/big_bird/tokenization_big_bird.py +1 -4
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
- transformers/models/biogpt/configuration_biogpt.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +73 -79
- transformers/models/biogpt/modular_biogpt.py +60 -66
- transformers/models/biogpt/tokenization_biogpt.py +3 -5
- transformers/models/bit/configuration_bit.py +2 -5
- transformers/models/bit/image_processing_bit.py +21 -24
- transformers/models/bit/image_processing_bit_fast.py +0 -1
- transformers/models/bit/modeling_bit.py +15 -16
- transformers/models/bitnet/configuration_bitnet.py +23 -28
- transformers/models/bitnet/modeling_bitnet.py +34 -38
- transformers/models/bitnet/modular_bitnet.py +7 -10
- transformers/models/blenderbot/configuration_blenderbot.py +8 -5
- transformers/models/blenderbot/modeling_blenderbot.py +68 -99
- transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
- transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
- transformers/models/blip/configuration_blip.py +9 -10
- transformers/models/blip/image_processing_blip.py +17 -20
- transformers/models/blip/image_processing_blip_fast.py +0 -1
- transformers/models/blip/modeling_blip.py +115 -108
- transformers/models/blip/modeling_blip_text.py +63 -65
- transformers/models/blip/processing_blip.py +5 -36
- transformers/models/blip_2/configuration_blip_2.py +2 -2
- transformers/models/blip_2/modeling_blip_2.py +145 -121
- transformers/models/blip_2/processing_blip_2.py +8 -38
- transformers/models/bloom/configuration_bloom.py +5 -2
- transformers/models/bloom/modeling_bloom.py +60 -60
- transformers/models/blt/configuration_blt.py +94 -86
- transformers/models/blt/modeling_blt.py +93 -90
- transformers/models/blt/modular_blt.py +127 -69
- transformers/models/bridgetower/configuration_bridgetower.py +7 -2
- transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
- transformers/models/bridgetower/modeling_bridgetower.py +136 -124
- transformers/models/bridgetower/processing_bridgetower.py +2 -16
- transformers/models/bros/configuration_bros.py +24 -18
- transformers/models/bros/modeling_bros.py +78 -80
- transformers/models/bros/processing_bros.py +2 -12
- transformers/models/byt5/tokenization_byt5.py +4 -6
- transformers/models/camembert/configuration_camembert.py +8 -2
- transformers/models/camembert/modeling_camembert.py +97 -99
- transformers/models/camembert/modular_camembert.py +51 -54
- transformers/models/camembert/tokenization_camembert.py +1 -4
- transformers/models/canine/configuration_canine.py +4 -2
- transformers/models/canine/modeling_canine.py +73 -75
- transformers/models/canine/tokenization_canine.py +0 -1
- transformers/models/chameleon/configuration_chameleon.py +29 -34
- transformers/models/chameleon/image_processing_chameleon.py +21 -24
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
- transformers/models/chameleon/modeling_chameleon.py +135 -92
- transformers/models/chameleon/processing_chameleon.py +16 -41
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
- transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
- transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
- transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
- transformers/models/clap/configuration_clap.py +4 -9
- transformers/models/clap/feature_extraction_clap.py +9 -10
- transformers/models/clap/modeling_clap.py +109 -111
- transformers/models/clap/processing_clap.py +2 -15
- transformers/models/clip/configuration_clip.py +4 -2
- transformers/models/clip/image_processing_clip.py +21 -24
- transformers/models/clip/image_processing_clip_fast.py +9 -1
- transformers/models/clip/modeling_clip.py +70 -68
- transformers/models/clip/processing_clip.py +2 -14
- transformers/models/clip/tokenization_clip.py +2 -5
- transformers/models/clipseg/configuration_clipseg.py +4 -2
- transformers/models/clipseg/modeling_clipseg.py +113 -112
- transformers/models/clipseg/processing_clipseg.py +19 -42
- transformers/models/clvp/configuration_clvp.py +15 -5
- transformers/models/clvp/feature_extraction_clvp.py +7 -10
- transformers/models/clvp/modeling_clvp.py +138 -145
- transformers/models/clvp/number_normalizer.py +1 -2
- transformers/models/clvp/processing_clvp.py +3 -20
- transformers/models/clvp/tokenization_clvp.py +0 -1
- transformers/models/code_llama/tokenization_code_llama.py +3 -6
- transformers/models/codegen/configuration_codegen.py +4 -4
- transformers/models/codegen/modeling_codegen.py +50 -49
- transformers/models/codegen/tokenization_codegen.py +5 -6
- transformers/models/cohere/configuration_cohere.py +25 -30
- transformers/models/cohere/modeling_cohere.py +39 -42
- transformers/models/cohere/modular_cohere.py +27 -31
- transformers/models/cohere/tokenization_cohere.py +5 -6
- transformers/models/cohere2/configuration_cohere2.py +27 -32
- transformers/models/cohere2/modeling_cohere2.py +38 -41
- transformers/models/cohere2/modular_cohere2.py +48 -52
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
- transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
- transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
- transformers/models/colpali/configuration_colpali.py +0 -1
- transformers/models/colpali/modeling_colpali.py +14 -16
- transformers/models/colpali/modular_colpali.py +11 -51
- transformers/models/colpali/processing_colpali.py +14 -52
- transformers/models/colqwen2/modeling_colqwen2.py +27 -28
- transformers/models/colqwen2/modular_colqwen2.py +36 -74
- transformers/models/colqwen2/processing_colqwen2.py +16 -52
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
- transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
- transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
- transformers/models/convbert/configuration_convbert.py +11 -8
- transformers/models/convbert/modeling_convbert.py +85 -87
- transformers/models/convbert/tokenization_convbert.py +0 -1
- transformers/models/convnext/configuration_convnext.py +2 -5
- transformers/models/convnext/image_processing_convnext.py +18 -21
- transformers/models/convnext/image_processing_convnext_fast.py +7 -8
- transformers/models/convnext/modeling_convnext.py +12 -14
- transformers/models/convnextv2/configuration_convnextv2.py +2 -5
- transformers/models/convnextv2/modeling_convnextv2.py +12 -14
- transformers/models/cpm/tokenization_cpm.py +6 -7
- transformers/models/cpm/tokenization_cpm_fast.py +3 -5
- transformers/models/cpmant/configuration_cpmant.py +4 -1
- transformers/models/cpmant/modeling_cpmant.py +38 -40
- transformers/models/cpmant/tokenization_cpmant.py +1 -3
- transformers/models/csm/configuration_csm.py +58 -66
- transformers/models/csm/generation_csm.py +13 -14
- transformers/models/csm/modeling_csm.py +81 -84
- transformers/models/csm/modular_csm.py +56 -58
- transformers/models/csm/processing_csm.py +25 -68
- transformers/models/ctrl/configuration_ctrl.py +16 -1
- transformers/models/ctrl/modeling_ctrl.py +51 -66
- transformers/models/ctrl/tokenization_ctrl.py +0 -1
- transformers/models/cvt/configuration_cvt.py +0 -1
- transformers/models/cvt/modeling_cvt.py +13 -15
- transformers/models/cwm/__init__.py +0 -1
- transformers/models/cwm/configuration_cwm.py +8 -12
- transformers/models/cwm/modeling_cwm.py +36 -38
- transformers/models/cwm/modular_cwm.py +10 -12
- transformers/models/d_fine/configuration_d_fine.py +10 -57
- transformers/models/d_fine/modeling_d_fine.py +786 -927
- transformers/models/d_fine/modular_d_fine.py +339 -417
- transformers/models/dab_detr/configuration_dab_detr.py +22 -49
- transformers/models/dab_detr/modeling_dab_detr.py +79 -77
- transformers/models/dac/configuration_dac.py +0 -1
- transformers/models/dac/feature_extraction_dac.py +6 -9
- transformers/models/dac/modeling_dac.py +22 -24
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
- transformers/models/data2vec/configuration_data2vec_text.py +11 -3
- transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
- transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
- transformers/models/data2vec/modeling_data2vec_text.py +97 -99
- transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
- transformers/models/data2vec/modular_data2vec_audio.py +6 -1
- transformers/models/data2vec/modular_data2vec_text.py +51 -54
- transformers/models/dbrx/configuration_dbrx.py +29 -22
- transformers/models/dbrx/modeling_dbrx.py +45 -48
- transformers/models/dbrx/modular_dbrx.py +37 -39
- transformers/models/deberta/configuration_deberta.py +6 -1
- transformers/models/deberta/modeling_deberta.py +57 -60
- transformers/models/deberta/tokenization_deberta.py +2 -5
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
- transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
- transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
- transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
- transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
- transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
- transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
- transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
- transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
- transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
- transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
- transformers/models/deit/configuration_deit.py +0 -1
- transformers/models/deit/image_processing_deit.py +18 -21
- transformers/models/deit/image_processing_deit_fast.py +0 -1
- transformers/models/deit/modeling_deit.py +27 -25
- transformers/models/depth_anything/configuration_depth_anything.py +12 -43
- transformers/models/depth_anything/modeling_depth_anything.py +10 -11
- transformers/models/depth_pro/configuration_depth_pro.py +0 -1
- transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
- transformers/models/depth_pro/modeling_depth_pro.py +29 -27
- transformers/models/detr/configuration_detr.py +18 -50
- transformers/models/detr/image_processing_detr.py +64 -66
- transformers/models/detr/image_processing_detr_fast.py +33 -34
- transformers/models/detr/modeling_detr.py +748 -789
- transformers/models/dia/configuration_dia.py +9 -15
- transformers/models/dia/feature_extraction_dia.py +6 -9
- transformers/models/dia/generation_dia.py +48 -53
- transformers/models/dia/modeling_dia.py +68 -71
- transformers/models/dia/modular_dia.py +56 -58
- transformers/models/dia/processing_dia.py +39 -29
- transformers/models/dia/tokenization_dia.py +3 -6
- transformers/models/diffllama/configuration_diffllama.py +25 -30
- transformers/models/diffllama/modeling_diffllama.py +45 -53
- transformers/models/diffllama/modular_diffllama.py +18 -25
- transformers/models/dinat/configuration_dinat.py +2 -5
- transformers/models/dinat/modeling_dinat.py +47 -48
- transformers/models/dinov2/configuration_dinov2.py +2 -5
- transformers/models/dinov2/modeling_dinov2.py +20 -21
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
- transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
- transformers/models/distilbert/configuration_distilbert.py +8 -2
- transformers/models/distilbert/modeling_distilbert.py +47 -49
- transformers/models/distilbert/tokenization_distilbert.py +0 -1
- transformers/models/doge/__init__.py +0 -1
- transformers/models/doge/configuration_doge.py +42 -35
- transformers/models/doge/modeling_doge.py +46 -49
- transformers/models/doge/modular_doge.py +77 -68
- transformers/models/donut/configuration_donut_swin.py +0 -1
- transformers/models/donut/image_processing_donut.py +26 -29
- transformers/models/donut/image_processing_donut_fast.py +9 -14
- transformers/models/donut/modeling_donut_swin.py +44 -46
- transformers/models/donut/processing_donut.py +5 -26
- transformers/models/dots1/configuration_dots1.py +43 -36
- transformers/models/dots1/modeling_dots1.py +35 -38
- transformers/models/dots1/modular_dots1.py +0 -1
- transformers/models/dpr/configuration_dpr.py +19 -2
- transformers/models/dpr/modeling_dpr.py +37 -39
- transformers/models/dpr/tokenization_dpr.py +7 -9
- transformers/models/dpr/tokenization_dpr_fast.py +7 -9
- transformers/models/dpt/configuration_dpt.py +23 -66
- transformers/models/dpt/image_processing_dpt.py +65 -66
- transformers/models/dpt/image_processing_dpt_fast.py +18 -19
- transformers/models/dpt/modeling_dpt.py +38 -36
- transformers/models/dpt/modular_dpt.py +14 -15
- transformers/models/edgetam/configuration_edgetam.py +1 -2
- transformers/models/edgetam/modeling_edgetam.py +87 -89
- transformers/models/edgetam/modular_edgetam.py +7 -13
- transformers/models/edgetam_video/__init__.py +0 -1
- transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
- transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
- transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
- transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
- transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
- transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
- transformers/models/efficientnet/configuration_efficientnet.py +0 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
- transformers/models/efficientnet/modeling_efficientnet.py +12 -14
- transformers/models/electra/configuration_electra.py +13 -3
- transformers/models/electra/modeling_electra.py +107 -109
- transformers/models/emu3/configuration_emu3.py +17 -17
- transformers/models/emu3/image_processing_emu3.py +44 -39
- transformers/models/emu3/modeling_emu3.py +143 -109
- transformers/models/emu3/modular_emu3.py +109 -73
- transformers/models/emu3/processing_emu3.py +18 -43
- transformers/models/encodec/configuration_encodec.py +2 -4
- transformers/models/encodec/feature_extraction_encodec.py +10 -13
- transformers/models/encodec/modeling_encodec.py +25 -29
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
- transformers/models/eomt/configuration_eomt.py +12 -14
- transformers/models/eomt/image_processing_eomt.py +53 -55
- transformers/models/eomt/image_processing_eomt_fast.py +18 -19
- transformers/models/eomt/modeling_eomt.py +19 -21
- transformers/models/eomt/modular_eomt.py +28 -30
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -3
- transformers/models/ernie/modeling_ernie.py +127 -162
- transformers/models/ernie/modular_ernie.py +91 -103
- transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
- transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
- transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
- transformers/models/esm/configuration_esm.py +11 -15
- transformers/models/esm/modeling_esm.py +35 -37
- transformers/models/esm/modeling_esmfold.py +43 -50
- transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
- transformers/models/esm/openfold_utils/loss.py +1 -2
- transformers/models/esm/openfold_utils/protein.py +15 -16
- transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
- transformers/models/esm/tokenization_esm.py +2 -4
- transformers/models/evolla/configuration_evolla.py +50 -40
- transformers/models/evolla/modeling_evolla.py +69 -68
- transformers/models/evolla/modular_evolla.py +50 -48
- transformers/models/evolla/processing_evolla.py +23 -35
- transformers/models/exaone4/configuration_exaone4.py +27 -27
- transformers/models/exaone4/modeling_exaone4.py +36 -39
- transformers/models/exaone4/modular_exaone4.py +51 -50
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +31 -26
- transformers/models/falcon/modeling_falcon.py +76 -84
- transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
- transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
- transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
- transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
- transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
- transformers/models/flaubert/configuration_flaubert.py +10 -5
- transformers/models/flaubert/modeling_flaubert.py +125 -129
- transformers/models/flaubert/tokenization_flaubert.py +3 -5
- transformers/models/flava/configuration_flava.py +9 -9
- transformers/models/flava/image_processing_flava.py +66 -67
- transformers/models/flava/image_processing_flava_fast.py +46 -47
- transformers/models/flava/modeling_flava.py +144 -135
- transformers/models/flava/processing_flava.py +2 -12
- transformers/models/flex_olmo/__init__.py +0 -1
- transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
- transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
- transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
- transformers/models/florence2/configuration_florence2.py +4 -1
- transformers/models/florence2/modeling_florence2.py +96 -72
- transformers/models/florence2/modular_florence2.py +100 -107
- transformers/models/florence2/processing_florence2.py +18 -47
- transformers/models/fnet/configuration_fnet.py +6 -2
- transformers/models/fnet/modeling_fnet.py +69 -80
- transformers/models/fnet/tokenization_fnet.py +0 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -5
- transformers/models/focalnet/modeling_focalnet.py +49 -48
- transformers/models/fsmt/configuration_fsmt.py +12 -17
- transformers/models/fsmt/modeling_fsmt.py +47 -48
- transformers/models/fsmt/tokenization_fsmt.py +3 -5
- transformers/models/funnel/configuration_funnel.py +8 -1
- transformers/models/funnel/modeling_funnel.py +91 -93
- transformers/models/funnel/tokenization_funnel.py +2 -5
- transformers/models/fuyu/configuration_fuyu.py +28 -34
- transformers/models/fuyu/image_processing_fuyu.py +29 -31
- transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
- transformers/models/fuyu/modeling_fuyu.py +50 -52
- transformers/models/fuyu/processing_fuyu.py +9 -36
- transformers/models/gemma/configuration_gemma.py +25 -30
- transformers/models/gemma/modeling_gemma.py +36 -38
- transformers/models/gemma/modular_gemma.py +33 -36
- transformers/models/gemma/tokenization_gemma.py +3 -6
- transformers/models/gemma2/configuration_gemma2.py +30 -35
- transformers/models/gemma2/modeling_gemma2.py +38 -41
- transformers/models/gemma2/modular_gemma2.py +63 -67
- transformers/models/gemma3/configuration_gemma3.py +53 -48
- transformers/models/gemma3/image_processing_gemma3.py +29 -31
- transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
- transformers/models/gemma3/modeling_gemma3.py +123 -122
- transformers/models/gemma3/modular_gemma3.py +128 -125
- transformers/models/gemma3/processing_gemma3.py +5 -5
- transformers/models/gemma3n/configuration_gemma3n.py +42 -30
- transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
- transformers/models/gemma3n/modeling_gemma3n.py +166 -147
- transformers/models/gemma3n/modular_gemma3n.py +176 -148
- transformers/models/gemma3n/processing_gemma3n.py +12 -26
- transformers/models/git/configuration_git.py +5 -8
- transformers/models/git/modeling_git.py +115 -127
- transformers/models/git/processing_git.py +2 -14
- transformers/models/glm/configuration_glm.py +26 -30
- transformers/models/glm/modeling_glm.py +36 -39
- transformers/models/glm/modular_glm.py +4 -7
- transformers/models/glm4/configuration_glm4.py +26 -30
- transformers/models/glm4/modeling_glm4.py +39 -41
- transformers/models/glm4/modular_glm4.py +8 -10
- transformers/models/glm46v/configuration_glm46v.py +4 -1
- transformers/models/glm46v/image_processing_glm46v.py +40 -38
- transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
- transformers/models/glm46v/modeling_glm46v.py +138 -93
- transformers/models/glm46v/modular_glm46v.py +5 -3
- transformers/models/glm46v/processing_glm46v.py +7 -41
- transformers/models/glm46v/video_processing_glm46v.py +9 -11
- transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
- transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
- transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
- transformers/models/glm4_moe_lite/__init__.py +28 -0
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
- transformers/models/glm4v/configuration_glm4v.py +25 -24
- transformers/models/glm4v/image_processing_glm4v.py +39 -38
- transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
- transformers/models/glm4v/modeling_glm4v.py +249 -210
- transformers/models/glm4v/modular_glm4v.py +211 -230
- transformers/models/glm4v/processing_glm4v.py +7 -41
- transformers/models/glm4v/video_processing_glm4v.py +9 -11
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
- transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
- transformers/models/glm_image/__init__.py +31 -0
- transformers/models/glm_image/configuration_glm_image.py +358 -0
- transformers/models/glm_image/image_processing_glm_image.py +503 -0
- transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
- transformers/models/glm_image/modeling_glm_image.py +1691 -0
- transformers/models/glm_image/modular_glm_image.py +1640 -0
- transformers/models/glm_image/processing_glm_image.py +265 -0
- transformers/models/glm_ocr/__init__.py +28 -0
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/__init__.py +0 -1
- transformers/models/glmasr/configuration_glmasr.py +0 -1
- transformers/models/glmasr/modeling_glmasr.py +51 -46
- transformers/models/glmasr/modular_glmasr.py +39 -29
- transformers/models/glmasr/processing_glmasr.py +7 -8
- transformers/models/glpn/configuration_glpn.py +0 -1
- transformers/models/glpn/image_processing_glpn.py +11 -12
- transformers/models/glpn/image_processing_glpn_fast.py +11 -12
- transformers/models/glpn/modeling_glpn.py +14 -14
- transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
- transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
- transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
- transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
- transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
- transformers/models/gpt2/configuration_gpt2.py +13 -2
- transformers/models/gpt2/modeling_gpt2.py +111 -113
- transformers/models/gpt2/tokenization_gpt2.py +6 -9
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
- transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
- transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
- transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
- transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
- transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
- transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
- transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
- transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
- transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
- transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
- transformers/models/gptj/configuration_gptj.py +4 -5
- transformers/models/gptj/modeling_gptj.py +85 -88
- transformers/models/granite/configuration_granite.py +28 -33
- transformers/models/granite/modeling_granite.py +43 -45
- transformers/models/granite/modular_granite.py +29 -31
- transformers/models/granite_speech/configuration_granite_speech.py +0 -1
- transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
- transformers/models/granite_speech/modeling_granite_speech.py +84 -60
- transformers/models/granite_speech/processing_granite_speech.py +11 -4
- transformers/models/granitemoe/configuration_granitemoe.py +31 -36
- transformers/models/granitemoe/modeling_granitemoe.py +39 -41
- transformers/models/granitemoe/modular_granitemoe.py +21 -23
- transformers/models/granitemoehybrid/__init__.py +0 -1
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
- transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
- transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
- transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
- transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
- transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
- transformers/models/groupvit/configuration_groupvit.py +4 -2
- transformers/models/groupvit/modeling_groupvit.py +98 -92
- transformers/models/helium/configuration_helium.py +25 -29
- transformers/models/helium/modeling_helium.py +37 -40
- transformers/models/helium/modular_helium.py +3 -7
- transformers/models/herbert/tokenization_herbert.py +4 -6
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
- transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
- transformers/models/hiera/configuration_hiera.py +2 -5
- transformers/models/hiera/modeling_hiera.py +71 -70
- transformers/models/hubert/configuration_hubert.py +4 -2
- transformers/models/hubert/modeling_hubert.py +42 -41
- transformers/models/hubert/modular_hubert.py +8 -11
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
- transformers/models/ibert/configuration_ibert.py +4 -2
- transformers/models/ibert/modeling_ibert.py +60 -62
- transformers/models/ibert/quant_modules.py +0 -1
- transformers/models/idefics/configuration_idefics.py +5 -8
- transformers/models/idefics/image_processing_idefics.py +13 -15
- transformers/models/idefics/modeling_idefics.py +63 -65
- transformers/models/idefics/perceiver.py +1 -3
- transformers/models/idefics/processing_idefics.py +32 -48
- transformers/models/idefics/vision.py +27 -28
- transformers/models/idefics2/configuration_idefics2.py +1 -3
- transformers/models/idefics2/image_processing_idefics2.py +31 -32
- transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
- transformers/models/idefics2/modeling_idefics2.py +126 -106
- transformers/models/idefics2/processing_idefics2.py +10 -68
- transformers/models/idefics3/configuration_idefics3.py +1 -4
- transformers/models/idefics3/image_processing_idefics3.py +42 -43
- transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
- transformers/models/idefics3/modeling_idefics3.py +113 -92
- transformers/models/idefics3/processing_idefics3.py +15 -69
- transformers/models/ijepa/configuration_ijepa.py +0 -1
- transformers/models/ijepa/modeling_ijepa.py +13 -14
- transformers/models/ijepa/modular_ijepa.py +5 -7
- transformers/models/imagegpt/configuration_imagegpt.py +9 -2
- transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
- transformers/models/imagegpt/modeling_imagegpt.py +65 -62
- transformers/models/informer/configuration_informer.py +6 -9
- transformers/models/informer/modeling_informer.py +87 -89
- transformers/models/informer/modular_informer.py +13 -16
- transformers/models/instructblip/configuration_instructblip.py +2 -2
- transformers/models/instructblip/modeling_instructblip.py +104 -79
- transformers/models/instructblip/processing_instructblip.py +10 -36
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
- transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
- transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
- transformers/models/internvl/configuration_internvl.py +5 -1
- transformers/models/internvl/modeling_internvl.py +76 -98
- transformers/models/internvl/modular_internvl.py +45 -59
- transformers/models/internvl/processing_internvl.py +12 -45
- transformers/models/internvl/video_processing_internvl.py +10 -11
- transformers/models/jais2/configuration_jais2.py +25 -29
- transformers/models/jais2/modeling_jais2.py +36 -38
- transformers/models/jais2/modular_jais2.py +20 -22
- transformers/models/jamba/configuration_jamba.py +5 -8
- transformers/models/jamba/modeling_jamba.py +47 -50
- transformers/models/jamba/modular_jamba.py +40 -41
- transformers/models/janus/configuration_janus.py +0 -1
- transformers/models/janus/image_processing_janus.py +37 -39
- transformers/models/janus/image_processing_janus_fast.py +20 -21
- transformers/models/janus/modeling_janus.py +103 -188
- transformers/models/janus/modular_janus.py +122 -83
- transformers/models/janus/processing_janus.py +17 -43
- transformers/models/jetmoe/configuration_jetmoe.py +26 -27
- transformers/models/jetmoe/modeling_jetmoe.py +42 -45
- transformers/models/jetmoe/modular_jetmoe.py +33 -36
- transformers/models/kosmos2/configuration_kosmos2.py +10 -9
- transformers/models/kosmos2/modeling_kosmos2.py +199 -178
- transformers/models/kosmos2/processing_kosmos2.py +40 -55
- transformers/models/kosmos2_5/__init__.py +0 -1
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
- transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
- transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
- transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
- transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
- transformers/models/lasr/configuration_lasr.py +3 -7
- transformers/models/lasr/feature_extraction_lasr.py +10 -12
- transformers/models/lasr/modeling_lasr.py +21 -24
- transformers/models/lasr/modular_lasr.py +11 -13
- transformers/models/lasr/processing_lasr.py +12 -6
- transformers/models/lasr/tokenization_lasr.py +2 -4
- transformers/models/layoutlm/configuration_layoutlm.py +14 -2
- transformers/models/layoutlm/modeling_layoutlm.py +70 -72
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
- transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
- transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
- transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
- transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
- transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
- transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
- transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
- transformers/models/led/configuration_led.py +8 -12
- transformers/models/led/modeling_led.py +113 -267
- transformers/models/levit/configuration_levit.py +0 -1
- transformers/models/levit/image_processing_levit.py +19 -21
- transformers/models/levit/image_processing_levit_fast.py +4 -5
- transformers/models/levit/modeling_levit.py +17 -19
- transformers/models/lfm2/configuration_lfm2.py +27 -30
- transformers/models/lfm2/modeling_lfm2.py +46 -48
- transformers/models/lfm2/modular_lfm2.py +32 -32
- transformers/models/lfm2_moe/__init__.py +0 -1
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
- transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
- transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
- transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
- transformers/models/lightglue/image_processing_lightglue.py +16 -15
- transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
- transformers/models/lightglue/modeling_lightglue.py +31 -33
- transformers/models/lightglue/modular_lightglue.py +31 -31
- transformers/models/lighton_ocr/__init__.py +28 -0
- transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
- transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
- transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
- transformers/models/lilt/configuration_lilt.py +6 -2
- transformers/models/lilt/modeling_lilt.py +53 -55
- transformers/models/llama/configuration_llama.py +26 -31
- transformers/models/llama/modeling_llama.py +35 -38
- transformers/models/llama/tokenization_llama.py +2 -4
- transformers/models/llama4/configuration_llama4.py +87 -69
- transformers/models/llama4/image_processing_llama4_fast.py +11 -12
- transformers/models/llama4/modeling_llama4.py +116 -115
- transformers/models/llama4/processing_llama4.py +33 -57
- transformers/models/llava/configuration_llava.py +10 -1
- transformers/models/llava/image_processing_llava.py +25 -28
- transformers/models/llava/image_processing_llava_fast.py +9 -10
- transformers/models/llava/modeling_llava.py +73 -102
- transformers/models/llava/processing_llava.py +18 -51
- transformers/models/llava_next/configuration_llava_next.py +2 -2
- transformers/models/llava_next/image_processing_llava_next.py +43 -45
- transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
- transformers/models/llava_next/modeling_llava_next.py +103 -104
- transformers/models/llava_next/processing_llava_next.py +18 -47
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
- transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
- transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
- transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
- transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
- transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
- transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
- transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
- transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
- transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
- transformers/models/longcat_flash/__init__.py +0 -1
- transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
- transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
- transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
- transformers/models/longformer/configuration_longformer.py +5 -5
- transformers/models/longformer/modeling_longformer.py +99 -101
- transformers/models/longt5/configuration_longt5.py +9 -7
- transformers/models/longt5/modeling_longt5.py +45 -45
- transformers/models/luke/configuration_luke.py +8 -2
- transformers/models/luke/modeling_luke.py +179 -181
- transformers/models/luke/tokenization_luke.py +99 -105
- transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
- transformers/models/lw_detr/configuration_lw_detr.py +362 -0
- transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
- transformers/models/lw_detr/modular_lw_detr.py +1609 -0
- transformers/models/lxmert/configuration_lxmert.py +16 -1
- transformers/models/lxmert/modeling_lxmert.py +63 -74
- transformers/models/m2m_100/configuration_m2m_100.py +7 -9
- transformers/models/m2m_100/modeling_m2m_100.py +72 -74
- transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
- transformers/models/mamba/configuration_mamba.py +5 -3
- transformers/models/mamba/modeling_mamba.py +61 -70
- transformers/models/mamba2/configuration_mamba2.py +5 -8
- transformers/models/mamba2/modeling_mamba2.py +66 -79
- transformers/models/marian/configuration_marian.py +10 -5
- transformers/models/marian/modeling_marian.py +88 -90
- transformers/models/marian/tokenization_marian.py +6 -6
- transformers/models/markuplm/configuration_markuplm.py +4 -7
- transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
- transformers/models/markuplm/modeling_markuplm.py +63 -65
- transformers/models/markuplm/processing_markuplm.py +31 -38
- transformers/models/markuplm/tokenization_markuplm.py +67 -77
- transformers/models/mask2former/configuration_mask2former.py +14 -52
- transformers/models/mask2former/image_processing_mask2former.py +84 -85
- transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
- transformers/models/mask2former/modeling_mask2former.py +108 -104
- transformers/models/mask2former/modular_mask2former.py +6 -8
- transformers/models/maskformer/configuration_maskformer.py +17 -51
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
- transformers/models/maskformer/image_processing_maskformer.py +84 -85
- transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
- transformers/models/maskformer/modeling_maskformer.py +71 -67
- transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
- transformers/models/mbart/configuration_mbart.py +9 -5
- transformers/models/mbart/modeling_mbart.py +120 -119
- transformers/models/mbart/tokenization_mbart.py +2 -4
- transformers/models/mbart50/tokenization_mbart50.py +3 -5
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
- transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
- transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
- transformers/models/mgp_str/configuration_mgp_str.py +0 -1
- transformers/models/mgp_str/modeling_mgp_str.py +18 -18
- transformers/models/mgp_str/processing_mgp_str.py +3 -20
- transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
- transformers/models/mimi/configuration_mimi.py +42 -40
- transformers/models/mimi/modeling_mimi.py +116 -115
- transformers/models/minimax/__init__.py +0 -1
- transformers/models/minimax/configuration_minimax.py +40 -47
- transformers/models/minimax/modeling_minimax.py +46 -49
- transformers/models/minimax/modular_minimax.py +59 -65
- transformers/models/minimax_m2/__init__.py +28 -0
- transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
- transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
- transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
- transformers/models/ministral/configuration_ministral.py +25 -29
- transformers/models/ministral/modeling_ministral.py +35 -37
- transformers/models/ministral/modular_ministral.py +32 -37
- transformers/models/ministral3/configuration_ministral3.py +23 -26
- transformers/models/ministral3/modeling_ministral3.py +35 -37
- transformers/models/ministral3/modular_ministral3.py +7 -8
- transformers/models/mistral/configuration_mistral.py +24 -29
- transformers/models/mistral/modeling_mistral.py +35 -37
- transformers/models/mistral/modular_mistral.py +14 -15
- transformers/models/mistral3/configuration_mistral3.py +4 -1
- transformers/models/mistral3/modeling_mistral3.py +79 -82
- transformers/models/mistral3/modular_mistral3.py +66 -67
- transformers/models/mixtral/configuration_mixtral.py +32 -38
- transformers/models/mixtral/modeling_mixtral.py +39 -42
- transformers/models/mixtral/modular_mixtral.py +26 -29
- transformers/models/mlcd/configuration_mlcd.py +0 -1
- transformers/models/mlcd/modeling_mlcd.py +17 -17
- transformers/models/mlcd/modular_mlcd.py +16 -16
- transformers/models/mllama/configuration_mllama.py +10 -15
- transformers/models/mllama/image_processing_mllama.py +23 -25
- transformers/models/mllama/image_processing_mllama_fast.py +11 -11
- transformers/models/mllama/modeling_mllama.py +100 -103
- transformers/models/mllama/processing_mllama.py +6 -55
- transformers/models/mluke/tokenization_mluke.py +97 -103
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
- transformers/models/mobilebert/configuration_mobilebert.py +4 -2
- transformers/models/mobilebert/modeling_mobilebert.py +78 -88
- transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
- transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
- transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
- transformers/models/mobilevit/configuration_mobilevit.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
- transformers/models/mobilevit/modeling_mobilevit.py +21 -21
- transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
- transformers/models/modernbert/configuration_modernbert.py +76 -51
- transformers/models/modernbert/modeling_modernbert.py +188 -943
- transformers/models/modernbert/modular_modernbert.py +255 -978
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
- transformers/models/moonshine/configuration_moonshine.py +34 -31
- transformers/models/moonshine/modeling_moonshine.py +70 -72
- transformers/models/moonshine/modular_moonshine.py +91 -86
- transformers/models/moshi/configuration_moshi.py +46 -23
- transformers/models/moshi/modeling_moshi.py +134 -142
- transformers/models/mpnet/configuration_mpnet.py +6 -2
- transformers/models/mpnet/modeling_mpnet.py +55 -57
- transformers/models/mpnet/tokenization_mpnet.py +1 -4
- transformers/models/mpt/configuration_mpt.py +17 -9
- transformers/models/mpt/modeling_mpt.py +58 -60
- transformers/models/mra/configuration_mra.py +8 -2
- transformers/models/mra/modeling_mra.py +54 -56
- transformers/models/mt5/configuration_mt5.py +9 -6
- transformers/models/mt5/modeling_mt5.py +80 -85
- transformers/models/musicgen/configuration_musicgen.py +12 -8
- transformers/models/musicgen/modeling_musicgen.py +114 -116
- transformers/models/musicgen/processing_musicgen.py +3 -21
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
- transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
- transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
- transformers/models/mvp/configuration_mvp.py +8 -5
- transformers/models/mvp/modeling_mvp.py +121 -123
- transformers/models/myt5/tokenization_myt5.py +8 -10
- transformers/models/nanochat/configuration_nanochat.py +5 -8
- transformers/models/nanochat/modeling_nanochat.py +36 -39
- transformers/models/nanochat/modular_nanochat.py +16 -18
- transformers/models/nemotron/configuration_nemotron.py +25 -30
- transformers/models/nemotron/modeling_nemotron.py +53 -66
- transformers/models/nllb/tokenization_nllb.py +14 -14
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
- transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
- transformers/models/nougat/image_processing_nougat.py +29 -32
- transformers/models/nougat/image_processing_nougat_fast.py +12 -13
- transformers/models/nougat/processing_nougat.py +37 -39
- transformers/models/nougat/tokenization_nougat.py +5 -7
- transformers/models/nystromformer/configuration_nystromformer.py +8 -2
- transformers/models/nystromformer/modeling_nystromformer.py +61 -63
- transformers/models/olmo/configuration_olmo.py +23 -28
- transformers/models/olmo/modeling_olmo.py +35 -38
- transformers/models/olmo/modular_olmo.py +8 -12
- transformers/models/olmo2/configuration_olmo2.py +27 -32
- transformers/models/olmo2/modeling_olmo2.py +36 -39
- transformers/models/olmo2/modular_olmo2.py +36 -38
- transformers/models/olmo3/__init__.py +0 -1
- transformers/models/olmo3/configuration_olmo3.py +30 -34
- transformers/models/olmo3/modeling_olmo3.py +35 -38
- transformers/models/olmo3/modular_olmo3.py +44 -47
- transformers/models/olmoe/configuration_olmoe.py +29 -33
- transformers/models/olmoe/modeling_olmoe.py +41 -43
- transformers/models/olmoe/modular_olmoe.py +15 -16
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
- transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
- transformers/models/oneformer/configuration_oneformer.py +11 -51
- transformers/models/oneformer/image_processing_oneformer.py +83 -84
- transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
- transformers/models/oneformer/modeling_oneformer.py +137 -133
- transformers/models/oneformer/processing_oneformer.py +28 -43
- transformers/models/openai/configuration_openai.py +16 -1
- transformers/models/openai/modeling_openai.py +50 -51
- transformers/models/openai/tokenization_openai.py +2 -5
- transformers/models/opt/configuration_opt.py +6 -7
- transformers/models/opt/modeling_opt.py +79 -80
- transformers/models/ovis2/__init__.py +0 -1
- transformers/models/ovis2/configuration_ovis2.py +4 -1
- transformers/models/ovis2/image_processing_ovis2.py +22 -24
- transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
- transformers/models/ovis2/modeling_ovis2.py +99 -142
- transformers/models/ovis2/modular_ovis2.py +82 -45
- transformers/models/ovis2/processing_ovis2.py +12 -40
- transformers/models/owlv2/configuration_owlv2.py +4 -2
- transformers/models/owlv2/image_processing_owlv2.py +20 -21
- transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
- transformers/models/owlv2/modeling_owlv2.py +122 -114
- transformers/models/owlv2/modular_owlv2.py +11 -12
- transformers/models/owlv2/processing_owlv2.py +20 -49
- transformers/models/owlvit/configuration_owlvit.py +4 -2
- transformers/models/owlvit/image_processing_owlvit.py +21 -22
- transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
- transformers/models/owlvit/modeling_owlvit.py +121 -113
- transformers/models/owlvit/processing_owlvit.py +20 -48
- transformers/models/paddleocr_vl/__init__.py +0 -1
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
- transformers/models/paligemma/configuration_paligemma.py +4 -1
- transformers/models/paligemma/modeling_paligemma.py +81 -79
- transformers/models/paligemma/processing_paligemma.py +13 -66
- transformers/models/parakeet/configuration_parakeet.py +3 -8
- transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
- transformers/models/parakeet/modeling_parakeet.py +21 -25
- transformers/models/parakeet/modular_parakeet.py +19 -21
- transformers/models/parakeet/processing_parakeet.py +12 -5
- transformers/models/parakeet/tokenization_parakeet.py +2 -4
- transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
- transformers/models/patchtst/configuration_patchtst.py +6 -9
- transformers/models/patchtst/modeling_patchtst.py +75 -77
- transformers/models/pe_audio/__init__.py +0 -1
- transformers/models/pe_audio/configuration_pe_audio.py +14 -16
- transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
- transformers/models/pe_audio/modeling_pe_audio.py +30 -31
- transformers/models/pe_audio/modular_pe_audio.py +17 -18
- transformers/models/pe_audio/processing_pe_audio.py +0 -1
- transformers/models/pe_audio_video/__init__.py +0 -1
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
- transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
- transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
- transformers/models/pe_video/__init__.py +0 -1
- transformers/models/pe_video/configuration_pe_video.py +14 -16
- transformers/models/pe_video/modeling_pe_video.py +57 -46
- transformers/models/pe_video/modular_pe_video.py +47 -35
- transformers/models/pe_video/video_processing_pe_video.py +2 -4
- transformers/models/pegasus/configuration_pegasus.py +8 -6
- transformers/models/pegasus/modeling_pegasus.py +67 -69
- transformers/models/pegasus/tokenization_pegasus.py +1 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
- transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
- transformers/models/perceiver/configuration_perceiver.py +0 -1
- transformers/models/perceiver/image_processing_perceiver.py +22 -25
- transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
- transformers/models/perceiver/modeling_perceiver.py +152 -145
- transformers/models/perceiver/tokenization_perceiver.py +3 -6
- transformers/models/perception_lm/configuration_perception_lm.py +0 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
- transformers/models/perception_lm/modeling_perception_lm.py +64 -67
- transformers/models/perception_lm/modular_perception_lm.py +58 -58
- transformers/models/perception_lm/processing_perception_lm.py +13 -47
- transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
- transformers/models/persimmon/configuration_persimmon.py +23 -28
- transformers/models/persimmon/modeling_persimmon.py +44 -47
- transformers/models/phi/configuration_phi.py +27 -28
- transformers/models/phi/modeling_phi.py +39 -41
- transformers/models/phi/modular_phi.py +26 -26
- transformers/models/phi3/configuration_phi3.py +32 -37
- transformers/models/phi3/modeling_phi3.py +37 -40
- transformers/models/phi3/modular_phi3.py +16 -20
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
- transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
- transformers/models/phimoe/configuration_phimoe.py +31 -36
- transformers/models/phimoe/modeling_phimoe.py +50 -77
- transformers/models/phimoe/modular_phimoe.py +12 -8
- transformers/models/phobert/tokenization_phobert.py +4 -6
- transformers/models/pix2struct/configuration_pix2struct.py +12 -10
- transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
- transformers/models/pix2struct/modeling_pix2struct.py +56 -52
- transformers/models/pix2struct/processing_pix2struct.py +5 -26
- transformers/models/pixio/__init__.py +0 -1
- transformers/models/pixio/configuration_pixio.py +2 -5
- transformers/models/pixio/modeling_pixio.py +16 -17
- transformers/models/pixio/modular_pixio.py +7 -8
- transformers/models/pixtral/configuration_pixtral.py +11 -14
- transformers/models/pixtral/image_processing_pixtral.py +26 -28
- transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
- transformers/models/pixtral/modeling_pixtral.py +31 -37
- transformers/models/pixtral/processing_pixtral.py +18 -52
- transformers/models/plbart/configuration_plbart.py +8 -6
- transformers/models/plbart/modeling_plbart.py +109 -109
- transformers/models/plbart/modular_plbart.py +31 -33
- transformers/models/plbart/tokenization_plbart.py +4 -5
- transformers/models/poolformer/configuration_poolformer.py +0 -1
- transformers/models/poolformer/image_processing_poolformer.py +21 -24
- transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
- transformers/models/poolformer/modeling_poolformer.py +10 -12
- transformers/models/pop2piano/configuration_pop2piano.py +7 -7
- transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
- transformers/models/pop2piano/modeling_pop2piano.py +24 -24
- transformers/models/pop2piano/processing_pop2piano.py +25 -33
- transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
- transformers/models/prophetnet/configuration_prophetnet.py +37 -38
- transformers/models/prophetnet/modeling_prophetnet.py +121 -153
- transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
- transformers/models/pvt/configuration_pvt.py +0 -1
- transformers/models/pvt/image_processing_pvt.py +24 -27
- transformers/models/pvt/image_processing_pvt_fast.py +1 -2
- transformers/models/pvt/modeling_pvt.py +19 -21
- transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
- transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
- transformers/models/qwen2/configuration_qwen2.py +32 -25
- transformers/models/qwen2/modeling_qwen2.py +35 -37
- transformers/models/qwen2/modular_qwen2.py +14 -15
- transformers/models/qwen2/tokenization_qwen2.py +2 -9
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
- transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
- transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
- transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
- transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
- transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
- transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
- transformers/models/qwen3/configuration_qwen3.py +34 -27
- transformers/models/qwen3/modeling_qwen3.py +35 -38
- transformers/models/qwen3/modular_qwen3.py +7 -9
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
- transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
- transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
- transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
- transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
- transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
- transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
- transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
- transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
- transformers/models/rag/configuration_rag.py +6 -7
- transformers/models/rag/modeling_rag.py +119 -121
- transformers/models/rag/retrieval_rag.py +3 -5
- transformers/models/rag/tokenization_rag.py +0 -50
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
- transformers/models/reformer/configuration_reformer.py +7 -8
- transformers/models/reformer/modeling_reformer.py +67 -68
- transformers/models/reformer/tokenization_reformer.py +3 -6
- transformers/models/regnet/configuration_regnet.py +0 -1
- transformers/models/regnet/modeling_regnet.py +7 -9
- transformers/models/rembert/configuration_rembert.py +8 -2
- transformers/models/rembert/modeling_rembert.py +108 -132
- transformers/models/rembert/tokenization_rembert.py +1 -4
- transformers/models/resnet/configuration_resnet.py +2 -5
- transformers/models/resnet/modeling_resnet.py +14 -15
- transformers/models/roberta/configuration_roberta.py +11 -3
- transformers/models/roberta/modeling_roberta.py +97 -99
- transformers/models/roberta/modular_roberta.py +55 -58
- transformers/models/roberta/tokenization_roberta.py +2 -5
- transformers/models/roberta/tokenization_roberta_old.py +2 -4
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
- transformers/models/roc_bert/configuration_roc_bert.py +8 -2
- transformers/models/roc_bert/modeling_roc_bert.py +125 -162
- transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
- transformers/models/roformer/configuration_roformer.py +13 -3
- transformers/models/roformer/modeling_roformer.py +79 -95
- transformers/models/roformer/tokenization_roformer.py +3 -6
- transformers/models/roformer/tokenization_utils.py +0 -1
- transformers/models/rt_detr/configuration_rt_detr.py +8 -50
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
- transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
- transformers/models/rt_detr/modeling_rt_detr.py +643 -804
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
- transformers/models/rt_detr/modular_rt_detr.py +1522 -20
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
- transformers/models/rwkv/configuration_rwkv.py +2 -4
- transformers/models/rwkv/modeling_rwkv.py +29 -54
- transformers/models/sam/configuration_sam.py +2 -1
- transformers/models/sam/image_processing_sam.py +59 -60
- transformers/models/sam/image_processing_sam_fast.py +25 -26
- transformers/models/sam/modeling_sam.py +46 -43
- transformers/models/sam/processing_sam.py +39 -27
- transformers/models/sam2/configuration_sam2.py +1 -2
- transformers/models/sam2/image_processing_sam2_fast.py +14 -15
- transformers/models/sam2/modeling_sam2.py +96 -94
- transformers/models/sam2/modular_sam2.py +85 -94
- transformers/models/sam2/processing_sam2.py +31 -47
- transformers/models/sam2_video/configuration_sam2_video.py +0 -1
- transformers/models/sam2_video/modeling_sam2_video.py +114 -116
- transformers/models/sam2_video/modular_sam2_video.py +72 -89
- transformers/models/sam2_video/processing_sam2_video.py +49 -66
- transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
- transformers/models/sam3/configuration_sam3.py +0 -1
- transformers/models/sam3/image_processing_sam3_fast.py +17 -20
- transformers/models/sam3/modeling_sam3.py +94 -100
- transformers/models/sam3/modular_sam3.py +3 -8
- transformers/models/sam3/processing_sam3.py +37 -52
- transformers/models/sam3_tracker/__init__.py +0 -1
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
- transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
- transformers/models/sam3_tracker_video/__init__.py +0 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
- transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
- transformers/models/sam3_video/configuration_sam3_video.py +0 -1
- transformers/models/sam3_video/modeling_sam3_video.py +56 -45
- transformers/models/sam3_video/processing_sam3_video.py +25 -45
- transformers/models/sam_hq/__init__.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +2 -1
- transformers/models/sam_hq/modeling_sam_hq.py +52 -50
- transformers/models/sam_hq/modular_sam_hq.py +23 -25
- transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
- transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
- transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
- transformers/models/seed_oss/configuration_seed_oss.py +30 -34
- transformers/models/seed_oss/modeling_seed_oss.py +34 -36
- transformers/models/seed_oss/modular_seed_oss.py +6 -7
- transformers/models/segformer/configuration_segformer.py +0 -10
- transformers/models/segformer/image_processing_segformer.py +39 -42
- transformers/models/segformer/image_processing_segformer_fast.py +11 -12
- transformers/models/segformer/modeling_segformer.py +28 -28
- transformers/models/segformer/modular_segformer.py +8 -9
- transformers/models/seggpt/configuration_seggpt.py +0 -1
- transformers/models/seggpt/image_processing_seggpt.py +38 -41
- transformers/models/seggpt/modeling_seggpt.py +48 -38
- transformers/models/sew/configuration_sew.py +4 -2
- transformers/models/sew/modeling_sew.py +42 -40
- transformers/models/sew/modular_sew.py +12 -13
- transformers/models/sew_d/configuration_sew_d.py +4 -2
- transformers/models/sew_d/modeling_sew_d.py +32 -31
- transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
- transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
- transformers/models/siglip/configuration_siglip.py +4 -2
- transformers/models/siglip/image_processing_siglip.py +17 -20
- transformers/models/siglip/image_processing_siglip_fast.py +0 -1
- transformers/models/siglip/modeling_siglip.py +65 -110
- transformers/models/siglip/processing_siglip.py +2 -14
- transformers/models/siglip/tokenization_siglip.py +6 -7
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2.py +15 -16
- transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
- transformers/models/siglip2/modeling_siglip2.py +89 -130
- transformers/models/siglip2/modular_siglip2.py +95 -48
- transformers/models/siglip2/processing_siglip2.py +2 -14
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +29 -32
- transformers/models/smollm3/modeling_smollm3.py +35 -38
- transformers/models/smollm3/modular_smollm3.py +36 -38
- transformers/models/smolvlm/configuration_smolvlm.py +2 -4
- transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
- transformers/models/smolvlm/modeling_smolvlm.py +124 -96
- transformers/models/smolvlm/modular_smolvlm.py +50 -39
- transformers/models/smolvlm/processing_smolvlm.py +15 -76
- transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
- transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
- transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
- transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
- transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
- transformers/models/speecht5/configuration_speecht5.py +7 -9
- transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
- transformers/models/speecht5/modeling_speecht5.py +172 -174
- transformers/models/speecht5/number_normalizer.py +0 -1
- transformers/models/speecht5/processing_speecht5.py +3 -37
- transformers/models/speecht5/tokenization_speecht5.py +4 -5
- transformers/models/splinter/configuration_splinter.py +6 -7
- transformers/models/splinter/modeling_splinter.py +62 -59
- transformers/models/splinter/tokenization_splinter.py +2 -4
- transformers/models/squeezebert/configuration_squeezebert.py +14 -2
- transformers/models/squeezebert/modeling_squeezebert.py +60 -62
- transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
- transformers/models/stablelm/configuration_stablelm.py +28 -29
- transformers/models/stablelm/modeling_stablelm.py +44 -47
- transformers/models/starcoder2/configuration_starcoder2.py +30 -27
- transformers/models/starcoder2/modeling_starcoder2.py +38 -41
- transformers/models/starcoder2/modular_starcoder2.py +17 -19
- transformers/models/superglue/configuration_superglue.py +7 -3
- transformers/models/superglue/image_processing_superglue.py +15 -15
- transformers/models/superglue/image_processing_superglue_fast.py +8 -8
- transformers/models/superglue/modeling_superglue.py +41 -37
- transformers/models/superpoint/image_processing_superpoint.py +15 -15
- transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
- transformers/models/superpoint/modeling_superpoint.py +17 -16
- transformers/models/swiftformer/configuration_swiftformer.py +0 -1
- transformers/models/swiftformer/modeling_swiftformer.py +12 -14
- transformers/models/swin/configuration_swin.py +2 -5
- transformers/models/swin/modeling_swin.py +69 -78
- transformers/models/swin2sr/configuration_swin2sr.py +0 -1
- transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
- transformers/models/swin2sr/modeling_swin2sr.py +30 -30
- transformers/models/swinv2/configuration_swinv2.py +2 -5
- transformers/models/swinv2/modeling_swinv2.py +65 -74
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
- transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
- transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
- transformers/models/t5/configuration_t5.py +9 -9
- transformers/models/t5/modeling_t5.py +80 -85
- transformers/models/t5/tokenization_t5.py +1 -3
- transformers/models/t5gemma/configuration_t5gemma.py +43 -59
- transformers/models/t5gemma/modeling_t5gemma.py +105 -108
- transformers/models/t5gemma/modular_t5gemma.py +128 -142
- transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
- transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
- transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
- transformers/models/table_transformer/configuration_table_transformer.py +18 -50
- transformers/models/table_transformer/modeling_table_transformer.py +73 -101
- transformers/models/tapas/configuration_tapas.py +12 -2
- transformers/models/tapas/modeling_tapas.py +65 -67
- transformers/models/tapas/tokenization_tapas.py +116 -153
- transformers/models/textnet/configuration_textnet.py +4 -7
- transformers/models/textnet/image_processing_textnet.py +22 -25
- transformers/models/textnet/image_processing_textnet_fast.py +8 -9
- transformers/models/textnet/modeling_textnet.py +28 -28
- transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
- transformers/models/timesfm/configuration_timesfm.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +22 -25
- transformers/models/timesfm/modular_timesfm.py +21 -24
- transformers/models/timesformer/configuration_timesformer.py +0 -1
- transformers/models/timesformer/modeling_timesformer.py +13 -16
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
- transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
- transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
- transformers/models/trocr/configuration_trocr.py +11 -8
- transformers/models/trocr/modeling_trocr.py +42 -42
- transformers/models/trocr/processing_trocr.py +5 -25
- transformers/models/tvp/configuration_tvp.py +10 -36
- transformers/models/tvp/image_processing_tvp.py +50 -52
- transformers/models/tvp/image_processing_tvp_fast.py +15 -15
- transformers/models/tvp/modeling_tvp.py +26 -28
- transformers/models/tvp/processing_tvp.py +2 -14
- transformers/models/udop/configuration_udop.py +16 -8
- transformers/models/udop/modeling_udop.py +73 -72
- transformers/models/udop/processing_udop.py +7 -26
- transformers/models/udop/tokenization_udop.py +80 -93
- transformers/models/umt5/configuration_umt5.py +8 -7
- transformers/models/umt5/modeling_umt5.py +87 -84
- transformers/models/unispeech/configuration_unispeech.py +4 -2
- transformers/models/unispeech/modeling_unispeech.py +54 -53
- transformers/models/unispeech/modular_unispeech.py +20 -22
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
- transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
- transformers/models/univnet/feature_extraction_univnet.py +14 -14
- transformers/models/univnet/modeling_univnet.py +7 -8
- transformers/models/upernet/configuration_upernet.py +8 -36
- transformers/models/upernet/modeling_upernet.py +11 -14
- transformers/models/vaultgemma/__init__.py +0 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
- transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
- transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
- transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
- transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
- transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
- transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
- transformers/models/video_llava/configuration_video_llava.py +4 -1
- transformers/models/video_llava/image_processing_video_llava.py +35 -38
- transformers/models/video_llava/modeling_video_llava.py +139 -143
- transformers/models/video_llava/processing_video_llava.py +38 -78
- transformers/models/video_llava/video_processing_video_llava.py +0 -1
- transformers/models/videomae/configuration_videomae.py +0 -1
- transformers/models/videomae/image_processing_videomae.py +31 -34
- transformers/models/videomae/modeling_videomae.py +17 -20
- transformers/models/videomae/video_processing_videomae.py +0 -1
- transformers/models/vilt/configuration_vilt.py +4 -2
- transformers/models/vilt/image_processing_vilt.py +29 -30
- transformers/models/vilt/image_processing_vilt_fast.py +15 -16
- transformers/models/vilt/modeling_vilt.py +103 -90
- transformers/models/vilt/processing_vilt.py +2 -14
- transformers/models/vipllava/configuration_vipllava.py +4 -1
- transformers/models/vipllava/modeling_vipllava.py +92 -67
- transformers/models/vipllava/modular_vipllava.py +78 -54
- transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
- transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
- transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
- transformers/models/visual_bert/configuration_visual_bert.py +6 -2
- transformers/models/visual_bert/modeling_visual_bert.py +90 -92
- transformers/models/vit/configuration_vit.py +2 -3
- transformers/models/vit/image_processing_vit.py +19 -22
- transformers/models/vit/image_processing_vit_fast.py +0 -1
- transformers/models/vit/modeling_vit.py +20 -20
- transformers/models/vit_mae/configuration_vit_mae.py +0 -1
- transformers/models/vit_mae/modeling_vit_mae.py +32 -30
- transformers/models/vit_msn/configuration_vit_msn.py +0 -1
- transformers/models/vit_msn/modeling_vit_msn.py +21 -19
- transformers/models/vitdet/configuration_vitdet.py +2 -5
- transformers/models/vitdet/modeling_vitdet.py +14 -17
- transformers/models/vitmatte/configuration_vitmatte.py +7 -39
- transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
- transformers/models/vitmatte/modeling_vitmatte.py +10 -12
- transformers/models/vitpose/configuration_vitpose.py +7 -47
- transformers/models/vitpose/image_processing_vitpose.py +24 -25
- transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
- transformers/models/vitpose/modeling_vitpose.py +15 -15
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
- transformers/models/vits/configuration_vits.py +4 -1
- transformers/models/vits/modeling_vits.py +43 -42
- transformers/models/vits/tokenization_vits.py +3 -4
- transformers/models/vivit/configuration_vivit.py +0 -1
- transformers/models/vivit/image_processing_vivit.py +36 -39
- transformers/models/vivit/modeling_vivit.py +9 -11
- transformers/models/vjepa2/__init__.py +0 -1
- transformers/models/vjepa2/configuration_vjepa2.py +0 -1
- transformers/models/vjepa2/modeling_vjepa2.py +39 -41
- transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
- transformers/models/voxtral/__init__.py +0 -1
- transformers/models/voxtral/configuration_voxtral.py +0 -2
- transformers/models/voxtral/modeling_voxtral.py +41 -48
- transformers/models/voxtral/modular_voxtral.py +35 -38
- transformers/models/voxtral/processing_voxtral.py +25 -48
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
- transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
- transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
- transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
- transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
- transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
- transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
- transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
- transformers/models/wavlm/configuration_wavlm.py +4 -2
- transformers/models/wavlm/modeling_wavlm.py +49 -49
- transformers/models/wavlm/modular_wavlm.py +4 -5
- transformers/models/whisper/configuration_whisper.py +6 -5
- transformers/models/whisper/english_normalizer.py +3 -4
- transformers/models/whisper/feature_extraction_whisper.py +9 -24
- transformers/models/whisper/generation_whisper.py +26 -49
- transformers/models/whisper/modeling_whisper.py +71 -73
- transformers/models/whisper/processing_whisper.py +3 -20
- transformers/models/whisper/tokenization_whisper.py +9 -30
- transformers/models/x_clip/configuration_x_clip.py +4 -2
- transformers/models/x_clip/modeling_x_clip.py +94 -96
- transformers/models/x_clip/processing_x_clip.py +2 -14
- transformers/models/xcodec/configuration_xcodec.py +4 -6
- transformers/models/xcodec/modeling_xcodec.py +15 -17
- transformers/models/xglm/configuration_xglm.py +9 -8
- transformers/models/xglm/modeling_xglm.py +49 -55
- transformers/models/xglm/tokenization_xglm.py +1 -4
- transformers/models/xlm/configuration_xlm.py +10 -8
- transformers/models/xlm/modeling_xlm.py +127 -131
- transformers/models/xlm/tokenization_xlm.py +3 -5
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
- transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
- transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
- transformers/models/xlnet/configuration_xlnet.py +3 -12
- transformers/models/xlnet/modeling_xlnet.py +149 -162
- transformers/models/xlnet/tokenization_xlnet.py +1 -4
- transformers/models/xlstm/configuration_xlstm.py +8 -12
- transformers/models/xlstm/modeling_xlstm.py +61 -96
- transformers/models/xmod/configuration_xmod.py +11 -3
- transformers/models/xmod/modeling_xmod.py +111 -116
- transformers/models/yolos/configuration_yolos.py +0 -1
- transformers/models/yolos/image_processing_yolos.py +60 -62
- transformers/models/yolos/image_processing_yolos_fast.py +42 -45
- transformers/models/yolos/modeling_yolos.py +19 -21
- transformers/models/yolos/modular_yolos.py +17 -19
- transformers/models/yoso/configuration_yoso.py +8 -2
- transformers/models/yoso/modeling_yoso.py +60 -62
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -8
- transformers/models/zamba/modeling_zamba.py +93 -125
- transformers/models/zamba2/configuration_zamba2.py +44 -50
- transformers/models/zamba2/modeling_zamba2.py +137 -165
- transformers/models/zamba2/modular_zamba2.py +79 -74
- transformers/models/zoedepth/configuration_zoedepth.py +17 -41
- transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
- transformers/models/zoedepth/modeling_zoedepth.py +19 -19
- transformers/pipelines/__init__.py +47 -106
- transformers/pipelines/any_to_any.py +15 -23
- transformers/pipelines/audio_utils.py +1 -2
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +13 -17
- transformers/pipelines/image_text_to_text.py +1 -2
- transformers/pipelines/question_answering.py +4 -43
- transformers/pipelines/text_classification.py +1 -14
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/pipelines/token_classification.py +1 -22
- transformers/pipelines/video_classification.py +1 -9
- transformers/pipelines/zero_shot_audio_classification.py +0 -1
- transformers/pipelines/zero_shot_classification.py +0 -6
- transformers/pipelines/zero_shot_image_classification.py +0 -7
- transformers/processing_utils.py +128 -137
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/base.py +10 -0
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_quark.py +0 -1
- transformers/quantizers/quantizer_torchao.py +3 -19
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +6 -65
- transformers/tokenization_mistral_common.py +563 -903
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +228 -341
- transformers/tokenization_utils_sentencepiece.py +5 -6
- transformers/tokenization_utils_tokenizers.py +36 -7
- transformers/trainer.py +30 -41
- transformers/trainer_jit_checkpoint.py +1 -2
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +414 -420
- transformers/utils/__init__.py +1 -4
- transformers/utils/attention_visualizer.py +1 -1
- transformers/utils/auto_docstring.py +567 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/dummy_pt_objects.py +0 -42
- transformers/utils/generic.py +70 -34
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +135 -107
- transformers/utils/quantization_config.py +8 -31
- transformers/video_processing_utils.py +24 -25
- transformers/video_utils.py +21 -23
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
- transformers-5.1.0.dist-info/RECORD +2092 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -229
- transformers-5.0.0rc2.dist-info/RECORD +0 -2042
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,9 @@
|
|
|
1
|
-
#
|
|
1
|
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
2
|
+
# This file was automatically generated from src/transformers/models/conditional_detr/modular_conditional_detr.py.
|
|
3
|
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
4
|
+
# the file from the modular. If any change should be done, please apply the change to the
|
|
5
|
+
# modular_conditional_detr.py file directly. One of our CI enforces this.
|
|
6
|
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
2
7
|
# Copyright 2022 Microsoft Research Asia and The HuggingFace Inc. team. All rights reserved.
|
|
3
8
|
#
|
|
4
9
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -12,40 +17,33 @@
|
|
|
12
17
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
18
|
# See the License for the specific language governing permissions and
|
|
14
19
|
# limitations under the License.
|
|
15
|
-
"""PyTorch Conditional DETR model."""
|
|
16
|
-
|
|
17
20
|
import math
|
|
21
|
+
from collections.abc import Callable
|
|
18
22
|
from dataclasses import dataclass
|
|
19
|
-
from typing import Optional, Union
|
|
20
23
|
|
|
21
24
|
import torch
|
|
22
|
-
from torch import
|
|
25
|
+
from torch import nn
|
|
23
26
|
|
|
24
27
|
from ... import initialization as init
|
|
25
28
|
from ...activations import ACT2FN
|
|
26
|
-
from ...
|
|
29
|
+
from ...backbone_utils import load_backbone
|
|
30
|
+
from ...masking_utils import create_bidirectional_mask
|
|
27
31
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
28
32
|
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
|
|
29
|
-
from ...modeling_utils import PreTrainedModel
|
|
30
|
-
from ...
|
|
31
|
-
from ...
|
|
33
|
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
34
|
+
from ...processing_utils import Unpack
|
|
35
|
+
from ...pytorch_utils import compile_compatible_method_lru_cache
|
|
36
|
+
from ...utils import ModelOutput, TransformersKwargs, auto_docstring
|
|
37
|
+
from ...utils.generic import OutputRecorder, can_return_tuple, check_model_inputs
|
|
32
38
|
from .configuration_conditional_detr import ConditionalDetrConfig
|
|
33
39
|
|
|
34
40
|
|
|
35
|
-
if is_timm_available():
|
|
36
|
-
from timm import create_model
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
logger = logging.get_logger(__name__)
|
|
40
|
-
|
|
41
|
-
|
|
42
41
|
@dataclass
|
|
43
42
|
@auto_docstring(
|
|
44
43
|
custom_intro="""
|
|
45
|
-
Base class for outputs of the
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
decoding losses.
|
|
44
|
+
Base class for outputs of the CONDITIONAL_DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
|
|
45
|
+
namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
|
|
46
|
+
gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
|
|
49
47
|
"""
|
|
50
48
|
)
|
|
51
49
|
class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
|
|
@@ -61,17 +59,17 @@ class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
|
|
|
61
59
|
Reference points (reference points of each layer of the decoder).
|
|
62
60
|
"""
|
|
63
61
|
|
|
64
|
-
intermediate_hidden_states:
|
|
65
|
-
|
|
62
|
+
intermediate_hidden_states: torch.FloatTensor | None = None
|
|
63
|
+
|
|
64
|
+
reference_points: tuple[torch.FloatTensor] | None = None
|
|
66
65
|
|
|
67
66
|
|
|
68
67
|
@dataclass
|
|
69
68
|
@auto_docstring(
|
|
70
69
|
custom_intro="""
|
|
71
|
-
Base class for outputs of the
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
losses.
|
|
70
|
+
Base class for outputs of the CONDITIONAL_DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
|
|
71
|
+
namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
|
|
72
|
+
gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
|
|
75
73
|
"""
|
|
76
74
|
)
|
|
77
75
|
class ConditionalDetrModelOutput(Seq2SeqModelOutput):
|
|
@@ -85,8 +83,9 @@ class ConditionalDetrModelOutput(Seq2SeqModelOutput):
|
|
|
85
83
|
Reference points (reference points of each layer of the decoder).
|
|
86
84
|
"""
|
|
87
85
|
|
|
88
|
-
intermediate_hidden_states:
|
|
89
|
-
|
|
86
|
+
intermediate_hidden_states: torch.FloatTensor | None = None
|
|
87
|
+
|
|
88
|
+
reference_points: tuple[torch.FloatTensor] | None = None
|
|
90
89
|
|
|
91
90
|
|
|
92
91
|
@dataclass
|
|
@@ -95,7 +94,6 @@ class ConditionalDetrModelOutput(Seq2SeqModelOutput):
|
|
|
95
94
|
Output type of [`ConditionalDetrForObjectDetection`].
|
|
96
95
|
"""
|
|
97
96
|
)
|
|
98
|
-
# Copied from transformers.models.detr.modeling_detr.DetrObjectDetectionOutput with Detr->ConditionalDetr
|
|
99
97
|
class ConditionalDetrObjectDetectionOutput(ModelOutput):
|
|
100
98
|
r"""
|
|
101
99
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
|
|
@@ -119,18 +117,18 @@ class ConditionalDetrObjectDetectionOutput(ModelOutput):
|
|
|
119
117
|
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
|
120
118
|
"""
|
|
121
119
|
|
|
122
|
-
loss:
|
|
123
|
-
loss_dict:
|
|
124
|
-
logits:
|
|
125
|
-
pred_boxes:
|
|
126
|
-
auxiliary_outputs:
|
|
127
|
-
last_hidden_state:
|
|
128
|
-
decoder_hidden_states:
|
|
129
|
-
decoder_attentions:
|
|
130
|
-
cross_attentions:
|
|
131
|
-
encoder_last_hidden_state:
|
|
132
|
-
encoder_hidden_states:
|
|
133
|
-
encoder_attentions:
|
|
120
|
+
loss: torch.FloatTensor | None = None
|
|
121
|
+
loss_dict: dict | None = None
|
|
122
|
+
logits: torch.FloatTensor | None = None
|
|
123
|
+
pred_boxes: torch.FloatTensor | None = None
|
|
124
|
+
auxiliary_outputs: list[dict] | None = None
|
|
125
|
+
last_hidden_state: torch.FloatTensor | None = None
|
|
126
|
+
decoder_hidden_states: tuple[torch.FloatTensor] | None = None
|
|
127
|
+
decoder_attentions: tuple[torch.FloatTensor] | None = None
|
|
128
|
+
cross_attentions: tuple[torch.FloatTensor] | None = None
|
|
129
|
+
encoder_last_hidden_state: torch.FloatTensor | None = None
|
|
130
|
+
encoder_hidden_states: tuple[torch.FloatTensor] | None = None
|
|
131
|
+
encoder_attentions: tuple[torch.FloatTensor] | None = None
|
|
134
132
|
|
|
135
133
|
|
|
136
134
|
@dataclass
|
|
@@ -139,7 +137,6 @@ class ConditionalDetrObjectDetectionOutput(ModelOutput):
|
|
|
139
137
|
Output type of [`ConditionalDetrForSegmentation`].
|
|
140
138
|
"""
|
|
141
139
|
)
|
|
142
|
-
# Copied from transformers.models.detr.modeling_detr.DetrSegmentationOutput with Detr->ConditionalDetr
|
|
143
140
|
class ConditionalDetrSegmentationOutput(ModelOutput):
|
|
144
141
|
r"""
|
|
145
142
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
|
|
@@ -169,22 +166,21 @@ class ConditionalDetrSegmentationOutput(ModelOutput):
|
|
|
169
166
|
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
|
170
167
|
"""
|
|
171
168
|
|
|
172
|
-
loss:
|
|
173
|
-
loss_dict:
|
|
174
|
-
logits:
|
|
175
|
-
pred_boxes:
|
|
176
|
-
pred_masks:
|
|
177
|
-
auxiliary_outputs:
|
|
178
|
-
last_hidden_state:
|
|
179
|
-
decoder_hidden_states:
|
|
180
|
-
decoder_attentions:
|
|
181
|
-
cross_attentions:
|
|
182
|
-
encoder_last_hidden_state:
|
|
183
|
-
encoder_hidden_states:
|
|
184
|
-
encoder_attentions:
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->ConditionalDetr
|
|
169
|
+
loss: torch.FloatTensor | None = None
|
|
170
|
+
loss_dict: dict | None = None
|
|
171
|
+
logits: torch.FloatTensor | None = None
|
|
172
|
+
pred_boxes: torch.FloatTensor | None = None
|
|
173
|
+
pred_masks: torch.FloatTensor | None = None
|
|
174
|
+
auxiliary_outputs: list[dict] | None = None
|
|
175
|
+
last_hidden_state: torch.FloatTensor | None = None
|
|
176
|
+
decoder_hidden_states: tuple[torch.FloatTensor] | None = None
|
|
177
|
+
decoder_attentions: tuple[torch.FloatTensor] | None = None
|
|
178
|
+
cross_attentions: tuple[torch.FloatTensor] | None = None
|
|
179
|
+
encoder_last_hidden_state: torch.FloatTensor | None = None
|
|
180
|
+
encoder_hidden_states: tuple[torch.FloatTensor] | None = None
|
|
181
|
+
encoder_attentions: tuple[torch.FloatTensor] | None = None
|
|
182
|
+
|
|
183
|
+
|
|
188
184
|
class ConditionalDetrFrozenBatchNorm2d(nn.Module):
|
|
189
185
|
"""
|
|
190
186
|
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
|
@@ -224,7 +220,6 @@ class ConditionalDetrFrozenBatchNorm2d(nn.Module):
|
|
|
224
220
|
return x * scale + bias
|
|
225
221
|
|
|
226
222
|
|
|
227
|
-
# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->ConditionalDetr
|
|
228
223
|
def replace_batch_norm(model):
|
|
229
224
|
r"""
|
|
230
225
|
Recursively replace all `torch.nn.BatchNorm2d` with `ConditionalDetrFrozenBatchNorm2d`.
|
|
@@ -249,7 +244,6 @@ def replace_batch_norm(model):
|
|
|
249
244
|
replace_batch_norm(module)
|
|
250
245
|
|
|
251
246
|
|
|
252
|
-
# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->ConditionalDetr
|
|
253
247
|
class ConditionalDetrConvEncoder(nn.Module):
|
|
254
248
|
"""
|
|
255
249
|
Convolutional backbone, using either the AutoBackbone API or one from the timm library.
|
|
@@ -263,47 +257,25 @@ class ConditionalDetrConvEncoder(nn.Module):
|
|
|
263
257
|
|
|
264
258
|
self.config = config
|
|
265
259
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
# We default to values which were previously hard-coded. This enables configurability from the config
|
|
269
|
-
# using backbone arguments, while keeping the default behavior the same.
|
|
270
|
-
requires_backends(self, ["timm"])
|
|
271
|
-
kwargs = getattr(config, "backbone_kwargs", {})
|
|
272
|
-
kwargs = {} if kwargs is None else kwargs.copy()
|
|
273
|
-
out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
|
|
274
|
-
num_channels = kwargs.pop("in_chans", config.num_channels)
|
|
275
|
-
if config.dilation:
|
|
276
|
-
kwargs["output_stride"] = kwargs.get("output_stride", 16)
|
|
277
|
-
backbone = create_model(
|
|
278
|
-
config.backbone,
|
|
279
|
-
pretrained=config.use_pretrained_backbone,
|
|
280
|
-
features_only=True,
|
|
281
|
-
out_indices=out_indices,
|
|
282
|
-
in_chans=num_channels,
|
|
283
|
-
**kwargs,
|
|
284
|
-
)
|
|
285
|
-
else:
|
|
286
|
-
backbone = load_backbone(config)
|
|
260
|
+
backbone = load_backbone(config)
|
|
261
|
+
self.intermediate_channel_sizes = backbone.channels
|
|
287
262
|
|
|
288
263
|
# replace batch norm by frozen batch norm
|
|
289
264
|
with torch.no_grad():
|
|
290
265
|
replace_batch_norm(backbone)
|
|
291
|
-
self.model = backbone
|
|
292
|
-
self.intermediate_channel_sizes = (
|
|
293
|
-
self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
|
|
294
|
-
)
|
|
295
266
|
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
267
|
+
# We used to load with timm library directly instead of the AutoBackbone API
|
|
268
|
+
# so we need to unwrap the `backbone._backbone` module to load weights without mismatch
|
|
269
|
+
is_timm_model = False
|
|
270
|
+
if hasattr(backbone, "_backbone"):
|
|
271
|
+
backbone = backbone._backbone
|
|
272
|
+
is_timm_model = True
|
|
273
|
+
self.model = backbone
|
|
303
274
|
|
|
275
|
+
backbone_model_type = config.backbone_config.model_type
|
|
304
276
|
if "resnet" in backbone_model_type:
|
|
305
277
|
for name, parameter in self.model.named_parameters():
|
|
306
|
-
if
|
|
278
|
+
if is_timm_model:
|
|
307
279
|
if "layer2" not in name and "layer3" not in name and "layer4" not in name:
|
|
308
280
|
parameter.requires_grad_(False)
|
|
309
281
|
else:
|
|
@@ -312,7 +284,9 @@ class ConditionalDetrConvEncoder(nn.Module):
|
|
|
312
284
|
|
|
313
285
|
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
|
|
314
286
|
# send pixel_values through the model to get list of feature maps
|
|
315
|
-
features = self.model(pixel_values)
|
|
287
|
+
features = self.model(pixel_values)
|
|
288
|
+
if isinstance(features, dict):
|
|
289
|
+
features = features.feature_maps
|
|
316
290
|
|
|
317
291
|
out = []
|
|
318
292
|
for feature_map in features:
|
|
@@ -322,66 +296,58 @@ class ConditionalDetrConvEncoder(nn.Module):
         return out
 
 
-# Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->ConditionalDetr
-class ConditionalDetrConvModel(nn.Module):
-    """
-    This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
-    """
-
-    def __init__(self, conv_encoder, position_embedding):
-        super().__init__()
-        self.conv_encoder = conv_encoder
-        self.position_embedding = position_embedding
-
-    def forward(self, pixel_values, pixel_mask):
-        # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
-        out = self.conv_encoder(pixel_values, pixel_mask)
-        pos = []
-        for feature_map, mask in out:
-            # position encoding
-            pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
-
-        return out, pos
-
-
 class ConditionalDetrSinePositionEmbedding(nn.Module):
     """
     This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
     need paper, generalized to work on images.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        num_position_features: int = 64,
+        temperature: int = 10000,
+        normalize: bool = False,
+        scale: float | None = None,
+    ):
         super().__init__()
-        self.embedding_dim = embedding_dim
-        self.temperature = temperature
-        self.normalize = normalize
         if scale is not None and normalize is False:
             raise ValueError("normalize should be True if scale is passed")
-
-
-        self.
+        self.num_position_features = num_position_features
+        self.temperature = temperature
+        self.normalize = normalize
+        self.scale = 2 * math.pi if scale is None else scale
 
-
-
-
-
-
+    @compile_compatible_method_lru_cache(maxsize=1)
+    def forward(
+        self,
+        shape: torch.Size,
+        device: torch.device | str,
+        dtype: torch.dtype,
+        mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if mask is None:
+            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+        y_embed = mask.cumsum(1, dtype=dtype)
+        x_embed = mask.cumsum(2, dtype=dtype)
         if self.normalize:
-
-
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
 
-        dim_t = torch.arange(self.
-        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.
+        dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
+        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)
 
         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+        # expected by the encoder
+        pos = pos.flatten(2).permute(0, 2, 1)
         return pos
 
 
-# Copied from transformers.models.detr.modeling_detr.DetrLearnedPositionEmbedding with Detr->ConditionalDetr
 class ConditionalDetrLearnedPositionEmbedding(nn.Module):
     """
     This module learns positional embeddings up to a fixed maximum size.
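The sine position embedding now works from a feature-map shape and flattens its output to `(batch_size, sequence_length, hidden_size)`. Below is a standalone sketch of the same computation, written against an explicit pixel mask (True for valid pixels, the usual DETR convention) instead of the module's cached default; names and defaults are illustrative.

```python
import math

import torch


def sine_position_embedding_2d(pixel_mask, num_position_features=64, temperature=10000, scale=2 * math.pi):
    # pixel_mask: (batch_size, height, width) bool, True for valid (non-padded) pixels
    y_embed = pixel_mask.to(torch.float32).cumsum(1)
    x_embed = pixel_mask.to(torch.float32).cumsum(2)
    eps = 1e-6
    y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale
    x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale

    dim_t = torch.arange(num_position_features, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_position_features)

    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
    pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
    # flatten the grid so each position becomes one token: (batch, height * width, 2 * num_position_features)
    return pos.flatten(2).permute(0, 2, 1)


mask = torch.ones(2, 16, 16, dtype=torch.bool)  # no padding
print(sine_position_embedding_2d(mask).shape)  # torch.Size([2, 256, 128])
```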
@@ -392,354 +358,385 @@ class ConditionalDetrLearnedPositionEmbedding(nn.Module):
         self.row_embeddings = nn.Embedding(50, embedding_dim)
         self.column_embeddings = nn.Embedding(50, embedding_dim)
 
-
-
-
-
+    @compile_compatible_method_lru_cache(maxsize=1)
+    def forward(
+        self,
+        shape: torch.Size,
+        device: torch.device | str,
+        dtype: torch.dtype,
+        mask: torch.Tensor | None = None,
+    ):
+        height, width = shape[-2:]
+        width_values = torch.arange(width, device=device)
+        height_values = torch.arange(height, device=device)
         x_emb = self.column_embeddings(width_values)
         y_emb = self.row_embeddings(height_values)
         pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
         pos = pos.permute(2, 0, 1)
         pos = pos.unsqueeze(0)
-        pos = pos.repeat(
+        pos = pos.repeat(shape[0], 1, 1, 1)
+        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+        # expected by the encoder
+        pos = pos.flatten(2).permute(0, 2, 1)
         return pos
 
 
-
-
-
-
-
-
-
-
-
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor | None,
+    scaling: float | None = None,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
 
-
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
 
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask
 
-
-
-    scale = 2 * math.pi
-    dim = d_model // 2
-    dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
-    dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
-    x_embed = pos_tensor[:, :, 0] * scale
-    y_embed = pos_tensor[:, :, 1] * scale
-    pos_x = x_embed[:, :, None] / dim_t
-    pos_y = y_embed[:, :, None] / dim_t
-    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
-    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
-    pos = torch.cat((pos_y, pos_x), dim=2)
-    return pos.to(pos_tensor.dtype)
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
 
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
 
-
-    x = x.clamp(min=0, max=1)
-    x1 = x.clamp(min=eps)
-    x2 = (1 - x).clamp(min=eps)
-    return torch.log(x1 / x2)
+    return attn_output, attn_weights
 
 
-
-class DetrAttention(nn.Module):
+class ConditionalDetrSelfAttention(nn.Module):
     """
-    Multi-headed attention from 'Attention Is All You Need' paper.
+    Multi-headed self-attention from 'Attention Is All You Need' paper.
 
-
+    In CONDITIONAL_DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
     """
 
     def __init__(
         self,
-
-
+        config: ConditionalDetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
         dropout: float = 0.0,
         bias: bool = True,
     ):
         super().__init__()
-        self.
-        self.
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        if self.head_dim * num_heads != self.embed_dim:
-            raise ValueError(
-                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {num_heads})."
-            )
+        self.config = config
+        self.head_dim = hidden_size // num_attention_heads
         self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+        self.is_causal = False
 
-        self.k_proj = nn.Linear(
-        self.v_proj = nn.Linear(
-        self.q_proj = nn.Linear(
-        self.
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
-        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]):
-        return tensor if object_queries is None else tensor + object_queries
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask:
-
-
-
-
-
-        """
-
-
-        is_cross_attention = key_value_states is not None
-        batch_size, target_len, embed_dim = hidden_states.size()
-
-        # add position embeddings to the hidden states before projecting to queries and keys
-        if object_queries is not None:
-            hidden_states_original = hidden_states
-            hidden_states = self.with_pos_embed(hidden_states, object_queries)
-
-        # add key-value position embeddings to the key value states
-        if spatial_position_embeddings is not None:
-            key_value_states_original = key_value_states
-            key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
-
-        # get query proj
-        query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        if is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
-            value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
-            value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
+        attention_mask: torch.Tensor | None = None,
+        position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Position embeddings are added to both queries and keys (but not values).
+        """
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
 
-
-        query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
 
-
+        query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
-
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
 
-
-
-
-
-
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
 
-
-
-
-                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
-                    f" {attention_mask.size()}"
-                )
-            if attention_mask.dtype == torch.bool:
-                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
-                    attention_mask, -torch.inf
-                )
-            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
-            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
-            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
-        else:
-            attn_weights_reshaped = None
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
 
-        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
-
+class ConditionalDetrDecoderSelfAttention(nn.Module):
+    """
+    Multi-headed self-attention for Conditional DETR decoder layers.
 
-
-
-
-            f" {attn_output.size()}"
-        )
+    This attention module handles separate content and position projections, which are then combined
+    before applying standard self-attention. Position embeddings are added to both queries and keys.
+    """
 
-
-
-
+    def __init__(
+        self,
+        config: ConditionalDetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = hidden_size
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+        self.is_causal = False
+
+        # Content and position projections
+        self.q_content_proj = nn.Linear(hidden_size, hidden_size)
+        self.q_pos_proj = nn.Linear(hidden_size, hidden_size)
+        self.k_content_proj = nn.Linear(hidden_size, hidden_size)
+        self.k_pos_proj = nn.Linear(hidden_size, hidden_size)
+        self.v_proj = nn.Linear(hidden_size, hidden_size)
+        self.o_proj = nn.Linear(hidden_size, hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        query_position_embeddings: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
+                Input hidden states from the decoder layer.
+            query_position_embeddings (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
+                Position embeddings for queries and keys. Required (unlike standard attention). Processed through
+                separate position projections (`q_pos_proj`, `k_pos_proj`) and added to content projections.
+            attention_mask (`torch.Tensor` of shape `(batch_size, 1, num_queries, num_queries)`, *optional*):
+                Attention mask to avoid attending to padding tokens.
+        """
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = (
+            (self.q_content_proj(hidden_states) + self.q_pos_proj(query_position_embeddings))
+            .view(hidden_shape)
+            .transpose(1, 2)
+        )
+        key_states = (
+            (self.k_content_proj(hidden_states) + self.k_pos_proj(query_position_embeddings))
+            .view(hidden_shape)
+            .transpose(1, 2)
+        )
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
-
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
 
-
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
 
 
-class
+class ConditionalDetrDecoderCrossAttention(nn.Module):
     """
-
+    Multi-headed cross-attention for Conditional DETR decoder layers.
 
-
-
+    This attention module handles the special cross-attention logic in Conditional DETR:
+    - Separate content and position projections for queries and keys
+    - Concatenation of query sine embeddings with queries (doubling query dimension)
+    - Concatenation of key position embeddings with keys (doubling key dimension)
+    - Output dimension remains hidden_size despite doubled input dimensions
     """
 
     def __init__(
         self,
-
-
-
+        config: ConditionalDetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
         dropout: float = 0.0,
-        bias: bool = True,
     ):
         super().__init__()
-        self.
-        self.
-        self.
-        self.
-        self.
-
-
-
-
-
-
-        self.
-
-
-
-
-        self.
-
-
-
-
-
-
-    def _v_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
-        return tensor.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()
+        self.config = config
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.head_dim = hidden_size // num_attention_heads
+        self.attention_dropout = dropout
+        self.is_causal = False
+
+        # Content and position projections
+        self.q_content_proj = nn.Linear(hidden_size, hidden_size)
+        self.q_pos_proj = nn.Linear(hidden_size, hidden_size)
+        self.k_content_proj = nn.Linear(hidden_size, hidden_size)
+        self.k_pos_proj = nn.Linear(hidden_size, hidden_size)
+        self.v_proj = nn.Linear(hidden_size, hidden_size)
+        self.q_pos_sine_proj = nn.Linear(hidden_size, hidden_size)
+
+        # Output projection: input is hidden_size * 2 (from concatenated q/k), output is hidden_size
+        self.o_proj = nn.Linear(hidden_size, hidden_size)
+
+        # Compute scaling for expanded head_dim (q and k have doubled dimensions after concatenation)
+        # This matches the original Conditional DETR implementation where embed_dim * 2 is used
+        expanded_head_dim = (hidden_size * 2) // num_attention_heads
+        self.scaling = expanded_head_dim**-0.5
 
     def forward(
         self,
         hidden_states: torch.Tensor,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        encoder_hidden_states: torch.Tensor,
+        query_sine_embed: torch.Tensor,
+        encoder_position_embeddings: torch.Tensor,
+        query_position_embeddings: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
+                Decoder hidden states (queries).
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, encoder_seq_len, hidden_size)`):
+                Encoder output hidden states (keys and values).
+            query_sine_embed (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
+                Sine position embeddings for queries. **Concatenated** (not added) with query content,
+                doubling the query dimension.
+            encoder_position_embeddings (`torch.Tensor` of shape `(batch_size, encoder_seq_len, hidden_size)`):
+                Position embeddings for keys. **Concatenated** (not added) with key content, doubling the key dimension.
+            query_position_embeddings (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+                Additional position embeddings. When provided (first layer only), **added** to query content
+                before concatenation with `query_sine_embed`. Also causes `encoder_position_embeddings` to be
+                added to key content before concatenation.
+            attention_mask (`torch.Tensor` of shape `(batch_size, 1, num_queries, encoder_seq_len)`, *optional*):
+                Attention mask to avoid attending to padding tokens.
+        """
+        query_input_shape = hidden_states.shape[:-1]
+        kv_input_shape = encoder_hidden_states.shape[:-1]
+        query_hidden_shape = (*query_input_shape, self.num_attention_heads, self.head_dim)
+        kv_hidden_shape = (*kv_input_shape, self.num_attention_heads, self.head_dim)
+
+        # Apply content and position projections
+        query_input = self.q_content_proj(hidden_states)
+        key_input = self.k_content_proj(encoder_hidden_states)
+        value_states = self.v_proj(encoder_hidden_states)
+        key_pos = self.k_pos_proj(encoder_position_embeddings)
+
+        # Combine content and position embeddings
+        if query_position_embeddings is not None:
+            query_input = query_input + self.q_pos_proj(query_position_embeddings)
+            key_input = key_input + key_pos
+
+        # Reshape and concatenate position embeddings (doubling head_dim)
+        query_input = query_input.view(query_hidden_shape)
+        key_input = key_input.view(kv_hidden_shape)
+        query_sine_embed = self.q_pos_sine_proj(query_sine_embed).view(query_hidden_shape)
+        key_pos = key_pos.view(kv_hidden_shape)
+
+        query_states = torch.cat([query_input, query_sine_embed], dim=-1).view(*query_input_shape, -1)
+        key_states = torch.cat([key_input, key_pos], dim=-1).view(*kv_input_shape, -1)
+
+        # Reshape for attention computation
+        expanded_head_dim = query_states.shape[-1] // self.num_attention_heads
+        query_states = query_states.view(*query_input_shape, self.num_attention_heads, expanded_head_dim).transpose(
+            1, 2
+        )
+        key_states = key_states.view(*kv_input_shape, self.num_attention_heads, expanded_head_dim).transpose(1, 2)
+        value_states = value_states.view(kv_hidden_shape).transpose(1, 2)
 
-
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
 
-        attn_output =
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
 
-
-
-
-            f" {attn_output.size()}"
-        )
+        attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
 
-        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.v_head_dim)
-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(batch_size, target_len, self.out_dim)
 
-
+class ConditionalDetrMLP(nn.Module):
+    def __init__(self, config: ConditionalDetrConfig, hidden_size: int, intermediate_size: int):
+        super().__init__()
+        self.fc1 = nn.Linear(hidden_size, intermediate_size)
+        self.fc2 = nn.Linear(intermediate_size, hidden_size)
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.dropout = config.dropout
 
-
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        return hidden_states
 
 
-
-class ConditionalDetrEncoderLayer(nn.Module):
+class ConditionalDetrEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: ConditionalDetrConfig):
         super().__init__()
-        self.
-        self.self_attn =
-
-
+        self.hidden_size = config.d_model
+        self.self_attn = ConditionalDetrSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
         )
-        self.self_attn_layer_norm = nn.LayerNorm(self.
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
         self.dropout = config.dropout
-        self.
-        self.
-        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
-        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.mlp = ConditionalDetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-
-
-    ):
+        spatial_position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len,
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-
-
-
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
+            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
+                Spatial position embeddings (2D positional encodings of image locations), to be added to both
+                the queries and keys in self-attention (but not to values).
         """
         residual = hidden_states
-        hidden_states,
+        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
-
-
+            position_embeddings=spatial_position_embeddings,
+            **kwargs,
         )
 
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
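All three new attention modules in the hunk above dispatch through `ALL_ATTENTION_FUNCTIONS.get_interface(...)` and fall back to `eager_attention_forward`. Here is a self-contained sketch of what that eager path computes; the function name and defaults are illustrative, not the library API.

```python
import torch
from torch import nn


def eager_attention(query, key, value, attention_mask=None, scaling=None, dropout_p=0.0, training=False):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # additive mask broadcast over heads; large negative values hide padded keys
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout_p, training=training)
    attn_output = torch.matmul(attn_weights, value)
    # back to (batch, seq_len, num_heads, head_dim) so it can be reshaped to (batch, seq_len, hidden)
    return attn_output.transpose(1, 2).contiguous(), attn_weights


q = k = v = torch.randn(1, 8, 10, 32)
out, weights = eager_attention(q, k, v)
print(out.shape, weights.shape)  # torch.Size([1, 10, 8, 32]) torch.Size([1, 8, 10, 10])
```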
@@ -747,12 +744,7 @@ class ConditionalDetrEncoderLayer(nn.Module):
         hidden_states = self.self_attn_layer_norm(hidden_states)
 
         residual = hidden_states
-        hidden_states = self.
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
+        hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
 
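The feed-forward path of both encoder and decoder layers is now factored into `ConditionalDetrMLP`. A hedged, config-free stand-in showing the same fc1 -> activation -> dropout -> fc2 -> dropout flow:

```python
import torch
from torch import nn


class FeedForwardSketch(nn.Module):
    # Standalone stand-in for ConditionalDetrMLP: two linear layers with an activation and
    # dropout after each stage; the config-driven pieces are replaced by plain arguments.
    def __init__(self, hidden_size, intermediate_size, dropout=0.1, activation_dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        self.activation_fn = nn.ReLU()
        self.activation_dropout = activation_dropout
        self.dropout = dropout

    def forward(self, hidden_states):
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        return nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)


mlp = FeedForwardSketch(hidden_size=256, intermediate_size=2048)
print(mlp(torch.randn(2, 100, 256)).shape)  # torch.Size([2, 100, 256])
```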
@@ -761,80 +753,55 @@ class ConditionalDetrEncoderLayer(nn.Module):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
 
-
-
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs
+        return hidden_states
 
 
 class ConditionalDetrDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: ConditionalDetrConfig):
         super().__init__()
-        self.
-
-
-
-
-        self.sa_qpos_proj = nn.Linear(d_model, d_model)
-        self.sa_kcontent_proj = nn.Linear(d_model, d_model)
-        self.sa_kpos_proj = nn.Linear(d_model, d_model)
-        self.sa_v_proj = nn.Linear(d_model, d_model)
-
-        self.self_attn = ConditionalDetrAttention(
-            embed_dim=self.embed_dim,
-            out_dim=self.embed_dim,
-            num_heads=config.decoder_attention_heads,
+        self.hidden_size = config.d_model
+        self.self_attn = ConditionalDetrDecoderSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.decoder_attention_heads,
             dropout=config.attention_dropout,
         )
         self.dropout = config.dropout
-        self.activation_fn = ACT2FN[config.activation_function]
-        self.activation_dropout = config.activation_dropout
-
-        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-
-        # Decoder Cross-Attention projections
-        self.ca_qcontent_proj = nn.Linear(d_model, d_model)
-        self.ca_qpos_proj = nn.Linear(d_model, d_model)
-        self.ca_kcontent_proj = nn.Linear(d_model, d_model)
-        self.ca_kpos_proj = nn.Linear(d_model, d_model)
-        self.ca_v_proj = nn.Linear(d_model, d_model)
-        self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
 
-        self.
-
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+        self.encoder_attn = ConditionalDetrDecoderCrossAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
         )
-        self.encoder_attn_layer_norm = nn.LayerNorm(self.
-        self.
-        self.
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.nhead = config.decoder_attention_heads
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+        self.mlp = ConditionalDetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask:
-
-        query_position_embeddings:
-        query_sine_embed:
-        encoder_hidden_states:
-        encoder_attention_mask:
-
-
-    ):
+        attention_mask: torch.Tensor | None = None,
+        spatial_position_embeddings: torch.Tensor | None = None,
+        query_position_embeddings: torch.Tensor | None = None,
+        query_sine_embed: torch.Tensor | None = None,
+        encoder_hidden_states: torch.Tensor | None = None,
+        encoder_attention_mask: torch.Tensor | None = None,
+        is_first: bool | None = False,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-
-
-                in the cross-attention layer.
+            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
             query_position_embeddings (`torch.FloatTensor`, *optional*):
                 object_queries that are added to the queries and keys
-
+                in the self-attention layer.
             encoder_hidden_states (`torch.FloatTensor`):
                 cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
             encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
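The decoder cross-attention above concatenates the projected sine embedding with the query content (and the key position embedding with the key content) instead of adding them, so queries and keys carry 2 * head_dim per head while values keep head_dim. A shape-only sketch of that bookkeeping with made-up sizes:

```python
import torch

batch_size, num_queries, encoder_len, hidden_size, num_heads = 2, 300, 950, 256, 8
head_dim = hidden_size // num_heads

query_content = torch.randn(batch_size, num_queries, num_heads, head_dim)
query_sine = torch.randn(batch_size, num_queries, num_heads, head_dim)
key_content = torch.randn(batch_size, encoder_len, num_heads, head_dim)
key_pos = torch.randn(batch_size, encoder_len, num_heads, head_dim)

# Concatenate along the per-head dimension: queries/keys end up with 2 * head_dim per head,
# which is why the scaling above uses (2 * hidden_size // num_heads) ** -0.5.
query_states = torch.cat([query_content, query_sine], dim=-1).transpose(1, 2)
key_states = torch.cat([key_content, key_pos], dim=-1).transpose(1, 2)
print(query_states.shape)  # torch.Size([2, 8, 300, 64])
print(key_states.shape)    # torch.Size([2, 8, 950, 64])

# Values are not concatenated, so the merged-head output stays at hidden_size.
value_states = torch.randn(batch_size, num_heads, encoder_len, head_dim)
attn = torch.softmax(query_states @ key_states.transpose(2, 3) * (2 * head_dim) ** -0.5, dim=-1)
print((attn @ value_states).transpose(1, 2).reshape(batch_size, num_queries, hidden_size).shape)
# torch.Size([2, 300, 256])
```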
@@ -846,108 +813,49 @@ class ConditionalDetrDecoderLayer(GradientCheckpointingLayer):
         """
         residual = hidden_states
 
-
-
-
-        q_content = self.sa_qcontent_proj(
-            hidden_states
-        )  # target is the input of the first decoder layer. zero by default.
-        q_pos = self.sa_qpos_proj(query_position_embeddings)
-        k_content = self.sa_kcontent_proj(hidden_states)
-        k_pos = self.sa_kpos_proj(query_position_embeddings)
-        v = self.sa_v_proj(hidden_states)
-
-        _, num_queries, n_model = q_content.shape
-
-        q = q_content + q_pos
-        k = k_content + k_pos
-        hidden_states, self_attn_weights = self.self_attn(
-            hidden_states=q,
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            query_position_embeddings=query_position_embeddings,
             attention_mask=attention_mask,
-
-            value_states=v,
-            output_attentions=output_attentions,
+            **kwargs,
         )
-        # ============ End of Self-Attention =============
 
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
 
-        # ========== Begin of Cross-Attention =============
-        # Apply projections here
-        # shape: num_queries x batch_size x 256
-        q_content = self.ca_qcontent_proj(hidden_states)
-        k_content = self.ca_kcontent_proj(encoder_hidden_states)
-        v = self.ca_v_proj(encoder_hidden_states)
-
-        batch_size, num_queries, n_model = q_content.shape
-        _, source_len, _ = k_content.shape
-
-        k_pos = self.ca_kpos_proj(object_queries)
-
-        # For the first decoder layer, we concatenate the positional embedding predicted from
-        # the object query (the positional embedding) into the original query (key) in DETR.
-        if is_first:
-            q_pos = self.ca_qpos_proj(query_position_embeddings)
-            q = q_content + q_pos
-            k = k_content + k_pos
-        else:
-            q = q_content
-            k = k_content
-
-        q = q.view(batch_size, num_queries, self.nhead, n_model // self.nhead)
-        query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
-        query_sine_embed = query_sine_embed.view(batch_size, num_queries, self.nhead, n_model // self.nhead)
-        q = torch.cat([q, query_sine_embed], dim=3).view(batch_size, num_queries, n_model * 2)
-        k = k.view(batch_size, source_len, self.nhead, n_model // self.nhead)
-        k_pos = k_pos.view(batch_size, source_len, self.nhead, n_model // self.nhead)
-        k = torch.cat([k, k_pos], dim=3).view(batch_size, source_len, n_model * 2)
-
-        # Cross-Attention Block
-        cross_attn_weights = None
         if encoder_hidden_states is not None:
             residual = hidden_states
 
-            hidden_states,
-                hidden_states=
+            hidden_states, _ = self.encoder_attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
-
-
-
+                query_sine_embed=query_sine_embed,
+                encoder_position_embeddings=spatial_position_embeddings,
+                # Only pass query_position_embeddings for the first layer
+                query_position_embeddings=query_position_embeddings if is_first else None,
+                **kwargs,
             )
 
             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
             hidden_states = residual + hidden_states
             hidden_states = self.encoder_attn_layer_norm(hidden_states)
 
-            # ============ End of Cross-Attention =============
-
         # Fully Connected
         residual = hidden_states
-        hidden_states = self.
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
 
-
+        return hidden_states
 
-        if output_attentions:
-            outputs += (self_attn_weights, cross_attn_weights)
 
-
-
-
-# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with DetrMLPPredictionHead->MLP
-class MLP(nn.Module):
+class ConditionalDetrMLPPredictionHead(nn.Module):
     """
     Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
     height and width of a bounding box w.r.t. an image.
 
-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-
     """
 
     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
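`ConditionalDetrMLPPredictionHead` keeps the usual DETR box-FFN pattern: a stack of linear layers with ReLU between all but the last. A small stand-in with hypothetical dimensions:

```python
import torch
from torch import nn


class BoxFFNSketch(nn.Module):
    # Same pattern as the prediction head above: num_layers linear layers,
    # ReLU on every layer except the final one.
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(d_in, d_out) for d_in, d_out in zip(dims, dims[1:] + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = torch.relu(layer(x)) if i < len(self.layers) - 1 else layer(x)
        return x


# e.g. predicting 4 normalized box coordinates from 256-d query embeddings
head = BoxFFNSketch(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
print(head(torch.randn(2, 300, 256)).shape)  # torch.Size([2, 300, 4])
```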
@@ -962,29 +870,202 @@ class MLP(nn.Module):
         return x
 
 
+class ConditionalDetrConvBlock(nn.Module):
+    """Basic conv block: Conv3x3 -> GroupNorm -> Activation."""
+
+    def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+        self.norm = nn.GroupNorm(min(8, out_channels), out_channels)
+        self.activation = ACT2FN[activation]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.activation(self.norm(self.conv(x)))
+
+
+class ConditionalDetrFPNFusionStage(nn.Module):
+    """Single FPN fusion stage combining low-resolution features with high-resolution FPN features."""
+
+    def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"):
+        super().__init__()
+        self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
+        self.refine = ConditionalDetrConvBlock(current_channels, output_channels, activation)
+
+    def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in)
+            fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out)
+
+        Returns:
+            Fused and refined features, shape (B*Q, output_channels, H_out, W_out)
+        """
+        fpn_features = self.fpn_adapter(fpn_features)
+        features = nn.functional.interpolate(features, size=fpn_features.shape[-2:], mode="nearest")
+        return self.refine(fpn_features + features)
+
+
+class ConditionalDetrMaskHeadSmallConv(nn.Module):
+    """
+    Segmentation mask head that generates per-query masks using FPN-based progressive upsampling.
+
+    Combines attention maps (spatial localization) with encoder features (semantics) and progressively
+    upsamples through multiple scales, fusing with FPN features for high-resolution detail.
+    """
+
+    def __init__(
+        self,
+        input_channels: int,
+        fpn_channels: list[int],
+        hidden_size: int,
+        activation_function: str = "relu",
+    ):
+        super().__init__()
+        if input_channels % 8 != 0:
+            raise ValueError(f"input_channels must be divisible by 8, got {input_channels}")
+
+        self.conv1 = ConditionalDetrConvBlock(input_channels, input_channels, activation_function)
+        self.conv2 = ConditionalDetrConvBlock(input_channels, hidden_size // 2, activation_function)
+
+        # Progressive channel reduction: /2 -> /4 -> /8 -> /16
+        self.fpn_stages = nn.ModuleList(
+            [
+                ConditionalDetrFPNFusionStage(
+                    fpn_channels[0], hidden_size // 2, hidden_size // 4, activation_function
+                ),
+                ConditionalDetrFPNFusionStage(
+                    fpn_channels[1], hidden_size // 4, hidden_size // 8, activation_function
+                ),
+                ConditionalDetrFPNFusionStage(
+                    fpn_channels[2], hidden_size // 8, hidden_size // 16, activation_function
+                ),
+            ]
+        )
+
+        self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1)
+
+    def forward(
+        self,
+        features: torch.Tensor,
+        attention_masks: torch.Tensor,
+        fpn_features: list[torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Args:
+            features: Encoder output features, shape (batch_size, hidden_size, H, W)
+            attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W)
+            fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, W)
+
+        Returns:
+            Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W)
+        """
+        num_queries = attention_masks.shape[1]
+
+        # Expand to (batch_size * num_queries) dimension
+        features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)
+        attention_masks = attention_masks.flatten(0, 1)
+        fpn_features = [
+            fpn_feat.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) for fpn_feat in fpn_features
+        ]
+
+        hidden_states = torch.cat([features, attention_masks], dim=1)
+        hidden_states = self.conv1(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        for fpn_stage, fpn_feat in zip(self.fpn_stages, fpn_features):
+            hidden_states = fpn_stage(hidden_states, fpn_feat)
+
+        return self.output_conv(hidden_states)
+
+
+class ConditionalDetrMHAttentionMap(nn.Module):
+    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+    def forward(
+        self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None
+    ):
+        query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim)
+        key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:])
+
+        query_states = self.q_proj(query_states).view(query_hidden_shape)
+        key_states = nn.functional.conv2d(
+            key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias
+        ).view(key_hidden_shape)
+
+        batch_size, num_queries, num_heads, head_dim = query_states.shape
+        _, _, _, height, width = key_states.shape
+        query_shape = (batch_size * num_heads, num_queries, head_dim)
+        key_shape = (batch_size * num_heads, height * width, head_dim)
+        attn_weights_shape = (batch_size, num_heads, num_queries, height, width)
+
+        query = query_states.transpose(1, 2).contiguous().view(query_shape)
+        key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape)
+
+        attn_weights = (
+            (torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2)
+        )
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size())
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+        return attn_weights
+
+
 @auto_docstring
-# Copied from transformers.models.detr.modeling_detr.DetrPreTrainedModel with Detr->ConditionalDetr
 class ConditionalDetrPreTrainedModel(PreTrainedModel):
     config: ConditionalDetrConfig
     base_model_prefix = "model"
     main_input_name = "pixel_values"
     input_modalities = ("image",)
     _no_split_modules = [r"ConditionalDetrConvEncoder", r"ConditionalDetrEncoderLayer", r"ConditionalDetrDecoderLayer"]
+    supports_gradient_checkpointing = True
+    _supports_sdpa = True
+    _supports_flash_attn = True
+    _supports_attention_backend = True
+    _supports_flex_attn = True  # Uses create_bidirectional_masks for attention masking
+    _keys_to_ignore_on_load_unexpected = [
+        r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
+    ]
 
     @torch.no_grad()
     def _init_weights(self, module):
         std = self.config.init_std
         xavier_std = self.config.init_xavier_std
 
-        if isinstance(module,
-
-
-
-
+        if isinstance(module, ConditionalDetrMaskHeadSmallConv):
+            # ConditionalDetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers
+            for m in module.modules():
+                if isinstance(m, nn.Conv2d):
+                    init.kaiming_uniform_(m.weight, a=1)
+                    if m.bias is not None:
+                        init.constant_(m.bias, 0)
+        elif isinstance(module, ConditionalDetrMHAttentionMap):
+            init.zeros_(module.k_proj.bias)
+            init.zeros_(module.q_proj.bias)
+            init.xavier_uniform_(module.k_proj.weight, gain=xavier_std)
+            init.xavier_uniform_(module.q_proj.weight, gain=xavier_std)
         elif isinstance(module, ConditionalDetrLearnedPositionEmbedding):
             init.uniform_(module.row_embeddings.weight)
             init.uniform_(module.column_embeddings.weight)
-
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
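`ConditionalDetrMHAttentionMap` above only produces attention weights over the encoder feature map (there is no value multiplication); the softmax runs over the flattened trailing dimensions. A shape-level sketch of that computation with made-up sizes:

```python
import torch

batch_size, num_queries, num_heads, head_dim, height, width = 2, 300, 8, 32, 25, 38

# queries: (batch, num_queries, heads, head_dim); keys: (batch, heads, head_dim, H, W)
queries = torch.randn(batch_size, num_queries, num_heads, head_dim)
keys = torch.randn(batch_size, num_heads, head_dim, height, width)

scaling = head_dim**-0.5
q = queries.transpose(1, 2).reshape(batch_size * num_heads, num_queries, head_dim)
k = keys.permute(0, 1, 3, 4, 2).reshape(batch_size * num_heads, height * width, head_dim)

weights = torch.matmul(q * scaling, k.transpose(1, 2))  # (batch * heads, num_queries, H * W)
weights = weights.view(batch_size, num_heads, num_queries, height, width).transpose(1, 2)
# softmax over the flattened trailing dimensions (heads and spatial grid), matching the module above
weights = torch.softmax(weights.flatten(2), dim=-1).view(weights.size())
print(weights.shape)  # torch.Size([2, 300, 8, 25, 38])
```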
@@ -998,50 +1079,38 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
                 init.zeros_(module.bias)
 
 
-# Copied from transformers.models.detr.modeling_detr.DetrEncoder with Detr->ConditionalDetr,DETR->ConditionalDETR
 class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
     """
-    Transformer encoder
-    [`ConditionalDetrEncoderLayer`].
-
-    The encoder updates the flattened feature map through multiple self-attention layers.
-
-    Small tweak for ConditionalDETR:
-
-    - object_queries are added to the forward pass.
+    Transformer encoder that processes a flattened feature map from a vision backbone, composed of a stack of
+    [`ConditionalDetrEncoderLayer`] modules.
 
     Args:
-        config:
+        config (`ConditionalDetrConfig`): Model configuration object.
     """
 
+    _can_record_outputs = {"hidden_states": ConditionalDetrEncoderLayer, "attentions": ConditionalDetrSelfAttention}
+
     def __init__(self, config: ConditionalDetrConfig):
         super().__init__(config)
 
         self.dropout = config.dropout
-        self.layerdrop = config.encoder_layerdrop
-
         self.layers = nn.ModuleList([ConditionalDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
 
-        # in the original ConditionalDETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default
-
         # Initialize weights and apply final processing
         self.post_init()
 
+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         attention_mask=None,
-
-
-
-        return_dict=None,
-        **kwargs,
-    ):
+        spatial_position_embeddings=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
-
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
 
|
@@ -1049,69 +1118,44 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
|
|
|
1049
1118
|
- 0 for pixel features that are padding (i.e. **masked**).
|
|
1050
1119
|
|
|
1051
1120
|
[What are attention masks?](../glossary#attention-mask)
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
Object queries that are added to the queries in each self-attention layer.
|
|
1055
|
-
|
|
1056
|
-
output_attentions (`bool`, *optional*):
|
|
1057
|
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
1058
|
-
returned tensors for more detail.
|
|
1059
|
-
output_hidden_states (`bool`, *optional*):
|
|
1060
|
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
|
1061
|
-
for more detail.
|
|
1062
|
-
return_dict (`bool`, *optional*):
|
|
1063
|
-
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
1121
|
+
spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
|
1122
|
+
Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
|
|
1064
1123
|
"""
|
|
1065
|
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
1066
|
-
output_hidden_states = (
|
|
1067
|
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
1068
|
-
)
|
|
1069
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1070
|
-
|
|
1071
1124
|
hidden_states = inputs_embeds
|
|
1072
1125
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
|
1073
1126
|
|
|
1074
1127
|
# expand attention_mask
|
|
1075
1128
|
if attention_mask is not None:
|
|
1076
1129
|
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
|
|
1077
|
-
attention_mask =
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
if output_hidden_states:
|
|
1083
|
-
encoder_states = encoder_states + (hidden_states,)
|
|
1084
|
-
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
|
|
1085
|
-
to_drop = False
|
|
1086
|
-
if self.training:
|
|
1087
|
-
dropout_probability = torch.rand([])
|
|
1088
|
-
if dropout_probability < self.layerdrop: # skip the layer
|
|
1089
|
-
to_drop = True
|
|
1130
|
+
attention_mask = create_bidirectional_mask(
|
|
1131
|
+
config=self.config,
|
|
1132
|
+
input_embeds=inputs_embeds,
|
|
1133
|
+
attention_mask=attention_mask,
|
|
1134
|
+
)
|
|
1090
1135
|
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
)
|
|
1136
|
+
for encoder_layer in self.layers:
|
|
1137
|
+
# we add spatial_position_embeddings as extra input to the encoder_layer
|
|
1138
|
+
hidden_states = encoder_layer(
|
|
1139
|
+
hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs
|
|
1140
|
+
)
|
|
1141
|
+
|
|
1142
|
+
return BaseModelOutput(last_hidden_state=hidden_states)
|
|
1143
|
+
|
|
1144
|
+
|
|
1145
|
+
# function to generate sine positional embedding for 2d coordinates
|
|
1146
|
+
def gen_sine_position_embeddings(pos_tensor, d_model):
|
|
1147
|
+
scale = 2 * math.pi
|
|
1148
|
+
dim = d_model // 2
|
|
1149
|
+
dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
|
|
1150
|
+
dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
|
|
1151
|
+
x_embed = pos_tensor[:, :, 0] * scale
|
|
1152
|
+
y_embed = pos_tensor[:, :, 1] * scale
|
|
1153
|
+
pos_x = x_embed[:, :, None] / dim_t
|
|
1154
|
+
pos_y = y_embed[:, :, None] / dim_t
|
|
1155
|
+
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
|
|
1156
|
+
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
|
|
1157
|
+
pos = torch.cat((pos_y, pos_x), dim=2)
|
|
1158
|
+
return pos.to(pos_tensor.dtype)
|
|
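A quick shape check for `gen_sine_position_embeddings` above (a usage sketch; it assumes the function defined above is in scope and feeds it dummy normalized `(x, y)` coordinates):

```python
import torch

# Dummy normalized (x, y) reference points for 2 images with 300 queries each.
pos_tensor = torch.rand(2, 300, 2)
d_model = 256

pos = gen_sine_position_embeddings(pos_tensor, d_model)
print(pos.shape)  # torch.Size([2, 300, 256]): first 128 dims encode y, last 128 encode x
```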
1115
1159
|
|
|
1116
1160
|
|
|
1117
1161
|
class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
|
@@ -1129,39 +1173,44 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
|
|
1129
1173
|
config: ConditionalDetrConfig
|
|
1130
1174
|
"""
|
|
1131
1175
|
|
|
1176
|
+
_can_record_outputs = {
|
|
1177
|
+
"hidden_states": ConditionalDetrDecoderLayer,
|
|
1178
|
+
"attentions": OutputRecorder(ConditionalDetrDecoderSelfAttention, layer_name="self_attn", index=1),
|
|
1179
|
+
"cross_attentions": OutputRecorder(ConditionalDetrDecoderCrossAttention, layer_name="encoder_attn", index=1),
|
|
1180
|
+
}
|
|
1181
|
+
|
|
1132
1182
|
def __init__(self, config: ConditionalDetrConfig):
|
|
1133
1183
|
super().__init__(config)
|
|
1184
|
+
self.hidden_size = config.d_model
|
|
1185
|
+
|
|
1134
1186
|
self.dropout = config.dropout
|
|
1135
1187
|
self.layerdrop = config.decoder_layerdrop
|
|
1136
1188
|
|
|
1137
1189
|
self.layers = nn.ModuleList([ConditionalDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
|
|
1138
1190
|
# in Conditional DETR, the decoder uses layernorm after the last decoder layer output
|
|
1139
1191
|
self.layernorm = nn.LayerNorm(config.d_model)
|
|
1140
|
-
d_model = config.d_model
|
|
1141
|
-
self.gradient_checkpointing = False
|
|
1142
1192
|
|
|
1143
1193
|
# query_scale is the FFN applied on f to generate transformation T
|
|
1144
|
-
self.query_scale =
|
|
1145
|
-
self.ref_point_head =
|
|
1194
|
+
self.query_scale = ConditionalDetrMLPPredictionHead(self.hidden_size, self.hidden_size, self.hidden_size, 2)
|
|
1195
|
+
self.ref_point_head = ConditionalDetrMLPPredictionHead(self.hidden_size, self.hidden_size, 2, 2)
|
|
1146
1196
|
for layer_id in range(config.decoder_layers - 1):
|
|
1147
|
-
|
|
1197
|
+
# Set q_pos_proj to None for layers after the first (only first layer uses query position embeddings)
|
|
1198
|
+
self.layers[layer_id + 1].encoder_attn.q_pos_proj = None
|
|
1148
1199
|
|
|
1149
1200
|
# Initialize weights and apply final processing
|
|
1150
1201
|
self.post_init()
|
|
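`query_scale` and `ref_point_head` above are instances of `ConditionalDetrMLPPredictionHead`; its structure can be read from the removed head definition later in this diff: a stack of linear layers with ReLU between all but the last. A standalone sketch of the same head:

```python
import torch
from torch import nn

class MLPPredictionHead(nn.Module):
    # Minimal equivalent of ConditionalDetrMLPPredictionHead: `num_layers` linear
    # layers with ReLU in between and no activation on the final layer.
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

# ref_point_head maps a query position embedding (d_model) to a 2D reference point.
ref_point_head = MLPPredictionHead(256, 256, 2, 2)
print(ref_point_head(torch.rand(2, 300, 256)).shape)  # torch.Size([2, 300, 2])
```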
1151
1202
|
|
|
1203
|
+
@check_model_inputs()
|
|
1152
1204
|
def forward(
|
|
1153
1205
|
self,
|
|
1154
1206
|
inputs_embeds=None,
|
|
1155
1207
|
attention_mask=None,
|
|
1156
1208
|
encoder_hidden_states=None,
|
|
1157
1209
|
encoder_attention_mask=None,
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
return_dict=None,
|
|
1163
|
-
**kwargs,
|
|
1164
|
-
):
|
|
1210
|
+
spatial_position_embeddings=None,
|
|
1211
|
+
object_queries_position_embeddings=None,
|
|
1212
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
1213
|
+
) -> ConditionalDetrDecoderOutput:
|
|
1165
1214
|
r"""
|
|
1166
1215
|
Args:
|
|
1167
1216
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
|
@@ -1184,46 +1233,28 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
|
|
1184
1233
|
- 1 for pixels that are real (i.e. **not masked**),
|
|
1185
1234
|
- 0 for pixels that are padding (i.e. **masked**).
|
|
1186
1235
|
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
|
|
1236
|
+
spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
1237
|
+
Spatial position embeddings that are added to the queries and keys in each cross-attention layer.
|
|
1238
|
+
object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
|
|
1190
1239
|
, *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
|
|
1191
|
-
output_attentions (`bool`, *optional*):
|
|
1192
|
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
1193
|
-
returned tensors for more detail.
|
|
1194
|
-
output_hidden_states (`bool`, *optional*):
|
|
1195
|
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
|
1196
|
-
for more detail.
|
|
1197
|
-
return_dict (`bool`, *optional*):
|
|
1198
|
-
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
1199
1240
|
"""
|
|
1200
|
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
1201
|
-
output_hidden_states = (
|
|
1202
|
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
1203
|
-
)
|
|
1204
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1205
|
-
|
|
1206
1241
|
if inputs_embeds is not None:
|
|
1207
1242
|
hidden_states = inputs_embeds
|
|
1208
|
-
input_shape = inputs_embeds.size()[:-1]
|
|
1209
1243
|
|
|
1210
1244
|
# expand encoder attention mask
|
|
1211
1245
|
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
|
1212
1246
|
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
|
|
1213
|
-
encoder_attention_mask =
|
|
1214
|
-
|
|
1247
|
+
encoder_attention_mask = create_bidirectional_mask(
|
|
1248
|
+
self.config,
|
|
1249
|
+
inputs_embeds,
|
|
1250
|
+
encoder_attention_mask,
|
|
1215
1251
|
)
|
|
1216
1252
|
|
|
1217
1253
|
# optional intermediate hidden states
|
|
1218
1254
|
intermediate = () if self.config.auxiliary_loss else None
|
|
1219
1255
|
|
|
1220
|
-
# decoder layers
|
|
1221
|
-
all_hidden_states = () if output_hidden_states else None
|
|
1222
|
-
all_self_attns = () if output_attentions else None
|
|
1223
|
-
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
|
1224
|
-
|
|
1225
1256
|
reference_points_before_sigmoid = self.ref_point_head(
|
|
1226
|
-
|
|
1257
|
+
object_queries_position_embeddings
|
|
1227
1258
|
) # [num_queries, batch_size, 2]
|
|
1228
1259
|
reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1)
|
|
1229
1260
|
obj_center = reference_points[..., :2].transpose(0, 1)
|
|
@@ -1231,9 +1262,6 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
|
|
1231
1262
|
query_sine_embed_before_transformation = gen_sine_position_embeddings(obj_center, self.config.d_model)
|
|
1232
1263
|
|
|
1233
1264
|
for idx, decoder_layer in enumerate(self.layers):
|
|
1234
|
-
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
|
|
1235
|
-
if output_hidden_states:
|
|
1236
|
-
all_hidden_states += (hidden_states,)
|
|
1237
1265
|
if self.training:
|
|
1238
1266
|
dropout_probability = torch.rand([])
|
|
1239
1267
|
if dropout_probability < self.layerdrop:
|
|
@@ -1245,59 +1273,31 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
|
|
1245
1273
|
# apply transformation
|
|
1246
1274
|
query_sine_embed = query_sine_embed_before_transformation * pos_transformation
|
|
1247
1275
|
|
|
1248
|
-
|
|
1276
|
+
hidden_states = decoder_layer(
|
|
1249
1277
|
hidden_states,
|
|
1250
|
-
None,
|
|
1251
|
-
|
|
1252
|
-
|
|
1278
|
+
None,
|
|
1279
|
+
spatial_position_embeddings,
|
|
1280
|
+
object_queries_position_embeddings,
|
|
1253
1281
|
query_sine_embed,
|
|
1254
1282
|
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
|
1255
1283
|
encoder_attention_mask=encoder_attention_mask,
|
|
1256
|
-
output_attentions=output_attentions,
|
|
1257
1284
|
is_first=(idx == 0),
|
|
1285
|
+
**kwargs,
|
|
1258
1286
|
)
|
|
1259
1287
|
|
|
1260
|
-
hidden_states = layer_outputs[0]
|
|
1261
|
-
|
|
1262
1288
|
if self.config.auxiliary_loss:
|
|
1263
1289
|
hidden_states = self.layernorm(hidden_states)
|
|
1264
1290
|
intermediate += (hidden_states,)
|
|
1265
1291
|
|
|
1266
|
-
if output_attentions:
|
|
1267
|
-
all_self_attns += (layer_outputs[1],)
|
|
1268
|
-
|
|
1269
|
-
if encoder_hidden_states is not None:
|
|
1270
|
-
all_cross_attentions += (layer_outputs[2],)
|
|
1271
|
-
|
|
1272
1292
|
# finally, apply layernorm
|
|
1273
1293
|
hidden_states = self.layernorm(hidden_states)
|
|
1274
1294
|
|
|
1275
|
-
# add hidden states from the last decoder layer
|
|
1276
|
-
if output_hidden_states:
|
|
1277
|
-
all_hidden_states += (hidden_states,)
|
|
1278
|
-
|
|
1279
1295
|
# stack intermediate decoder activations
|
|
1280
1296
|
if self.config.auxiliary_loss:
|
|
1281
1297
|
intermediate = torch.stack(intermediate)
|
|
1282
1298
|
|
|
1283
|
-
if not return_dict:
|
|
1284
|
-
return tuple(
|
|
1285
|
-
v
|
|
1286
|
-
for v in [
|
|
1287
|
-
hidden_states,
|
|
1288
|
-
all_hidden_states,
|
|
1289
|
-
all_self_attns,
|
|
1290
|
-
all_cross_attentions,
|
|
1291
|
-
intermediate,
|
|
1292
|
-
reference_points,
|
|
1293
|
-
]
|
|
1294
|
-
if v is not None
|
|
1295
|
-
)
|
|
1296
1299
|
return ConditionalDetrDecoderOutput(
|
|
1297
1300
|
last_hidden_state=hidden_states,
|
|
1298
|
-
hidden_states=all_hidden_states,
|
|
1299
|
-
attentions=all_self_attns,
|
|
1300
|
-
cross_attentions=all_cross_attentions,
|
|
1301
1301
|
intermediate_hidden_states=intermediate,
|
|
1302
1302
|
reference_points=reference_points,
|
|
1303
1303
|
)
|
|
@@ -1305,23 +1305,24 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
|
|
1305
1305
|
|
|
1306
1306
|
@auto_docstring(
|
|
1307
1307
|
custom_intro="""
|
|
1308
|
-
The bare
|
|
1309
|
-
|
|
1308
|
+
The bare CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
|
|
1309
|
+
any specific head on top.
|
|
1310
1310
|
"""
|
|
1311
1311
|
)
|
|
1312
1312
|
class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
|
|
1313
1313
|
def __init__(self, config: ConditionalDetrConfig):
|
|
1314
1314
|
super().__init__(config)
|
|
1315
1315
|
|
|
1316
|
-
|
|
1317
|
-
backbone = ConditionalDetrConvEncoder(config)
|
|
1318
|
-
object_queries = build_position_encoding(config)
|
|
1319
|
-
self.backbone = ConditionalDetrConvModel(backbone, object_queries)
|
|
1320
|
-
|
|
1321
|
-
# Create projection layer
|
|
1322
|
-
self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
|
|
1316
|
+
self.backbone = ConditionalDetrConvEncoder(config)
|
|
1323
1317
|
|
|
1318
|
+
if config.position_embedding_type == "sine":
|
|
1319
|
+
self.position_embedding = ConditionalDetrSinePositionEmbedding(config.d_model // 2, normalize=True)
|
|
1320
|
+
elif config.position_embedding_type == "learned":
|
|
1321
|
+
self.position_embedding = ConditionalDetrLearnedPositionEmbedding(config.d_model // 2)
|
|
1322
|
+
else:
|
|
1323
|
+
raise ValueError(f"Not supported {config.position_embedding_type}")
|
|
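The position embedding is now chosen directly in `ConditionalDetrModel.__init__` from `config.position_embedding_type`. A usage sketch (values as handled by the branch above; any other string raises the `ValueError`):

```python
from transformers import ConditionalDetrConfig, ConditionalDetrModel

config = ConditionalDetrConfig(position_embedding_type="learned")  # "sine" is the default
model = ConditionalDetrModel(config)
```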
1324
1324
|
self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
|
|
1325
|
+
self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
|
|
1325
1326
|
|
|
1326
1327
|
self.encoder = ConditionalDetrEncoder(config)
|
|
1327
1328
|
self.decoder = ConditionalDetrDecoder(config)
|
|
@@ -1330,27 +1331,25 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
|
|
|
1330
1331
|
self.post_init()
|
|
1331
1332
|
|
|
1332
1333
|
def freeze_backbone(self):
|
|
1333
|
-
for
|
|
1334
|
+
for _, param in self.backbone.model.named_parameters():
|
|
1334
1335
|
param.requires_grad_(False)
|
|
1335
1336
|
|
|
1336
1337
|
def unfreeze_backbone(self):
|
|
1337
|
-
for
|
|
1338
|
+
for _, param in self.backbone.model.named_parameters():
|
|
1338
1339
|
param.requires_grad_(True)
|
|
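A usage sketch for the freezing helpers when only the transformer and heads should be trained; `model.model` is the inner `ConditionalDetrModel`, and the checkpoint name is the `microsoft/conditional-detr-resnet-50` one quoted elsewhere in this file:

```python
from transformers import ConditionalDetrForObjectDetection

model = ConditionalDetrForObjectDetection.from_pretrained("microsoft/conditional-detr-resnet-50")
model.model.freeze_backbone()    # backbone parameters get requires_grad=False
# ... fine-tune the transformer and detection heads ...
model.model.unfreeze_backbone()  # re-enable backbone training later if needed
```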
1339
1340
|
|
|
1340
1341
|
@auto_docstring
|
|
1342
|
+
@can_return_tuple
|
|
1341
1343
|
def forward(
|
|
1342
1344
|
self,
|
|
1343
1345
|
pixel_values: torch.FloatTensor,
|
|
1344
|
-
pixel_mask:
|
|
1345
|
-
decoder_attention_mask:
|
|
1346
|
-
encoder_outputs:
|
|
1347
|
-
inputs_embeds:
|
|
1348
|
-
decoder_inputs_embeds:
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
return_dict: Optional[bool] = None,
|
|
1352
|
-
**kwargs,
|
|
1353
|
-
) -> Union[tuple[torch.FloatTensor], ConditionalDetrModelOutput]:
|
|
1346
|
+
pixel_mask: torch.LongTensor | None = None,
|
|
1347
|
+
decoder_attention_mask: torch.LongTensor | None = None,
|
|
1348
|
+
encoder_outputs: torch.FloatTensor | None = None,
|
|
1349
|
+
inputs_embeds: torch.FloatTensor | None = None,
|
|
1350
|
+
decoder_inputs_embeds: torch.FloatTensor | None = None,
|
|
1351
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
1352
|
+
) -> ConditionalDetrModelOutput:
|
|
1354
1353
|
r"""
|
|
1355
1354
|
decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
|
|
1356
1355
|
Not used by default. Can be used to mask object queries.
|
|
@@ -1386,12 +1385,6 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
|
|
|
1386
1385
|
>>> list(last_hidden_states.shape)
|
|
1387
1386
|
[1, 300, 256]
|
|
1388
1387
|
```"""
|
|
1389
|
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
1390
|
-
output_hidden_states = (
|
|
1391
|
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
1392
|
-
)
|
|
1393
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1394
|
-
|
|
1395
1388
|
batch_size, num_channels, height, width = pixel_values.shape
|
|
1396
1389
|
device = pixel_values.device
|
|
1397
1390
|
|
|
@@ -1401,7 +1394,7 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
|
|
|
1401
1394
|
# First, sent pixel_values + pixel_mask through Backbone to obtain the features
|
|
1402
1395
|
# pixel_values should be of shape (batch_size, num_channels, height, width)
|
|
1403
1396
|
# pixel_mask should be of shape (batch_size, height, width)
|
|
1404
|
-
features
|
|
1397
|
+
features = self.backbone(pixel_values, pixel_mask)
|
|
1405
1398
|
|
|
1406
1399
|
# get final feature map and downsampled mask
|
|
1407
1400
|
feature_map, mask = features[-1]
|
|
@@ -1412,53 +1405,52 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
|
|
|
1412
1405
|
# Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
|
|
1413
1406
|
projected_feature_map = self.input_projection(feature_map)
|
|
1414
1407
|
|
|
1415
|
-
#
|
|
1408
|
+
# Generate position embeddings
|
|
1409
|
+
spatial_position_embeddings = self.position_embedding(
|
|
1410
|
+
shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
|
|
1411
|
+
)
|
|
1412
|
+
|
|
1413
|
+
# Third, flatten the feature map of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
|
|
1416
1414
|
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
|
|
1417
1415
|
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
|
1418
|
-
object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
|
|
1419
1416
|
|
|
1420
1417
|
flattened_mask = mask.flatten(1)
|
|
1421
1418
|
|
|
1422
|
-
# Fourth, sent flattened_features + flattened_mask +
|
|
1419
|
+
# Fourth, send flattened_features + flattened_mask + spatial_position_embeddings through encoder
|
|
1423
1420
|
# flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
|
|
1424
1421
|
# flattened_mask is a Tensor of shape (batch_size, height*width)
|
|
1425
1422
|
if encoder_outputs is None:
|
|
1426
1423
|
encoder_outputs = self.encoder(
|
|
1427
1424
|
inputs_embeds=flattened_features,
|
|
1428
1425
|
attention_mask=flattened_mask,
|
|
1429
|
-
|
|
1430
|
-
|
|
1431
|
-
output_hidden_states=output_hidden_states,
|
|
1432
|
-
return_dict=return_dict,
|
|
1426
|
+
spatial_position_embeddings=spatial_position_embeddings,
|
|
1427
|
+
**kwargs,
|
|
1433
1428
|
)
|
|
1434
|
-
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
|
|
1435
|
-
elif
|
|
1429
|
+
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
|
|
1430
|
+
elif not isinstance(encoder_outputs, BaseModelOutput):
|
|
1436
1431
|
encoder_outputs = BaseModelOutput(
|
|
1437
1432
|
last_hidden_state=encoder_outputs[0],
|
|
1438
1433
|
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
|
1439
1434
|
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
|
1440
1435
|
)
|
|
1441
1436
|
|
|
1442
|
-
# Fifth, sent query embeddings
|
|
1443
|
-
|
|
1444
|
-
|
|
1437
|
+
# Fifth, send query embeddings through the decoder (which is conditioned on the encoder output)
|
|
1438
|
+
object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(
|
|
1439
|
+
batch_size, 1, 1
|
|
1440
|
+
)
|
|
1441
|
+
queries = torch.zeros_like(object_queries_position_embeddings)
|
|
1445
1442
|
|
|
1446
1443
|
# decoder outputs consists of (dec_features, dec_hidden, dec_attn)
|
|
1447
1444
|
decoder_outputs = self.decoder(
|
|
1448
1445
|
inputs_embeds=queries,
|
|
1449
1446
|
attention_mask=None,
|
|
1450
|
-
|
|
1451
|
-
|
|
1452
|
-
encoder_hidden_states=encoder_outputs
|
|
1447
|
+
spatial_position_embeddings=spatial_position_embeddings,
|
|
1448
|
+
object_queries_position_embeddings=object_queries_position_embeddings,
|
|
1449
|
+
encoder_hidden_states=encoder_outputs.last_hidden_state,
|
|
1453
1450
|
encoder_attention_mask=flattened_mask,
|
|
1454
|
-
|
|
1455
|
-
output_hidden_states=output_hidden_states,
|
|
1456
|
-
return_dict=return_dict,
|
|
1451
|
+
**kwargs,
|
|
1457
1452
|
)
|
|
1458
1453
|
|
|
1459
|
-
if not return_dict:
|
|
1460
|
-
return decoder_outputs + encoder_outputs
|
|
1461
|
-
|
|
1462
1454
|
return ConditionalDetrModelOutput(
|
|
1463
1455
|
last_hidden_state=decoder_outputs.last_hidden_state,
|
|
1464
1456
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
|
@@ -1472,45 +1464,26 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
|
|
|
1472
1464
|
)
|
|
1473
1465
|
|
|
1474
1466
|
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
|
1482
|
-
|
|
1483
|
-
"""
|
|
1484
|
-
|
|
1485
|
-
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
|
1486
|
-
super().__init__()
|
|
1487
|
-
self.num_layers = num_layers
|
|
1488
|
-
h = [hidden_dim] * (num_layers - 1)
|
|
1489
|
-
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
|
1490
|
-
|
|
1491
|
-
def forward(self, x):
|
|
1492
|
-
for i, layer in enumerate(self.layers):
|
|
1493
|
-
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
|
1494
|
-
return x
|
|
1467
|
+
def inverse_sigmoid(x, eps=1e-5):
|
|
1468
|
+
x = x.clamp(min=0, max=1)
|
|
1469
|
+
x1 = x.clamp(min=eps)
|
|
1470
|
+
x2 = (1 - x).clamp(min=eps)
|
|
1471
|
+
return torch.log(x1 / x2)
|
|
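A quick numerical check of the `inverse_sigmoid` helper added above (assumes the function above is in scope):

```python
import torch

x = torch.tensor([0.0, 0.25, 0.5, 1.0])
y = inverse_sigmoid(x)            # logit with eps-clamping at the boundaries
print(torch.sigmoid(y))           # ~[1e-5, 0.25, 0.5, 1 - 1e-5]: round-trips except at the clamped endpoints
print(inverse_sigmoid(torch.tensor([0.5])))  # tensor([0.]) -- logit(0.5) == 0
```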
1495
1472
|
|
|
1496
1473
|
|
|
1497
1474
|
@auto_docstring(
|
|
1498
1475
|
custom_intro="""
|
|
1499
|
-
|
|
1500
|
-
|
|
1476
|
+
CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
|
|
1477
|
+
such as COCO detection.
|
|
1501
1478
|
"""
|
|
1502
1479
|
)
|
|
1503
1480
|
class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
|
|
1504
1481
|
def __init__(self, config: ConditionalDetrConfig):
|
|
1505
1482
|
super().__init__(config)
|
|
1506
1483
|
|
|
1507
|
-
#
|
|
1484
|
+
# CONDITIONAL_DETR encoder-decoder model
|
|
1508
1485
|
self.model = ConditionalDetrModel(config)
|
|
1509
|
-
|
|
1510
|
-
# Object detection heads
|
|
1511
|
-
self.class_labels_classifier = nn.Linear(
|
|
1512
|
-
config.d_model, config.num_labels
|
|
1513
|
-
) # We add one for the "no object" class
|
|
1486
|
+
self.class_labels_classifier = nn.Linear(config.d_model, config.num_labels)
|
|
1514
1487
|
self.bbox_predictor = ConditionalDetrMLPPredictionHead(
|
|
1515
1488
|
input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
|
|
1516
1489
|
)
|
|
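For reference, the two heads above map each decoder query to `num_labels` class logits and 4 normalized box coordinates; a shape sketch with illustrative sizes (`num_labels=91` is only an example value, and the random tensors stand in for the real modules created in `__init__`):

```python
import torch

batch_size, num_queries, d_model, num_labels = 1, 300, 256, 91
sequence_output = torch.rand(batch_size, num_queries, d_model)   # decoder output

class_logits = torch.nn.Linear(d_model, num_labels)(sequence_output)
pred_box_deltas = torch.rand(batch_size, num_queries, 4)          # stands in for the 3-layer MLP output

print(class_logits.shape)               # torch.Size([1, 300, 91])
print(pred_box_deltas.sigmoid().shape)  # torch.Size([1, 300, 4]) -- normalized (cx, cy, w, h)
```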
@@ -1518,25 +1491,19 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
|
|
|
1518
1491
|
# Initialize weights and apply final processing
|
|
1519
1492
|
self.post_init()
|
|
1520
1493
|
|
|
1521
|
-
# taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/conditional_detr.py
|
|
1522
|
-
def _set_aux_loss(self, outputs_class, outputs_coord):
|
|
1523
|
-
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
|
|
1524
|
-
|
|
1525
1494
|
@auto_docstring
|
|
1495
|
+
@can_return_tuple
|
|
1526
1496
|
def forward(
|
|
1527
1497
|
self,
|
|
1528
1498
|
pixel_values: torch.FloatTensor,
|
|
1529
|
-
pixel_mask:
|
|
1530
|
-
decoder_attention_mask:
|
|
1531
|
-
encoder_outputs:
|
|
1532
|
-
inputs_embeds:
|
|
1533
|
-
decoder_inputs_embeds:
|
|
1534
|
-
labels:
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
return_dict: Optional[bool] = None,
|
|
1538
|
-
**kwargs,
|
|
1539
|
-
) -> Union[tuple[torch.FloatTensor], ConditionalDetrObjectDetectionOutput]:
|
|
1499
|
+
pixel_mask: torch.LongTensor | None = None,
|
|
1500
|
+
decoder_attention_mask: torch.LongTensor | None = None,
|
|
1501
|
+
encoder_outputs: torch.FloatTensor | None = None,
|
|
1502
|
+
inputs_embeds: torch.FloatTensor | None = None,
|
|
1503
|
+
decoder_inputs_embeds: torch.FloatTensor | None = None,
|
|
1504
|
+
labels: list[dict] | None = None,
|
|
1505
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
1506
|
+
) -> ConditionalDetrObjectDetectionOutput:
|
|
1540
1507
|
r"""
|
|
1541
1508
|
decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
|
|
1542
1509
|
Not used by default. Can be used to mask object queries.
|
|
@@ -1586,8 +1553,6 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
|
|
|
1586
1553
|
Detected remote with confidence 0.683 at location [334.48, 73.49, 366.37, 190.01]
|
|
1587
1554
|
Detected couch with confidence 0.535 at location [0.52, 1.19, 640.35, 475.1]
|
|
1588
1555
|
```"""
|
|
1589
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1590
|
-
|
|
1591
1556
|
# First, sent images through CONDITIONAL_DETR base model to obtain encoder + decoder outputs
|
|
1592
1557
|
outputs = self.model(
|
|
1593
1558
|
pixel_values,
|
|
@@ -1596,9 +1561,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
|
|
|
1596
1561
|
encoder_outputs=encoder_outputs,
|
|
1597
1562
|
inputs_embeds=inputs_embeds,
|
|
1598
1563
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
|
1599
|
-
|
|
1600
|
-
output_hidden_states=output_hidden_states,
|
|
1601
|
-
return_dict=return_dict,
|
|
1564
|
+
**kwargs,
|
|
1602
1565
|
)
|
|
1603
1566
|
|
|
1604
1567
|
sequence_output = outputs[0]
|
|
@@ -1606,11 +1569,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
|
|
|
1606
1569
|
# class logits + predicted bounding boxes
|
|
1607
1570
|
logits = self.class_labels_classifier(sequence_output)
|
|
1608
1571
|
|
|
1609
|
-
|
|
1610
|
-
# are not specified, otherwise it will be another index which is hard to determine.
|
|
1611
|
-
# Leave it as is, because it's not a common case to use
|
|
1612
|
-
# return_dict=False + output_attentions=True / output_hidden_states=True
|
|
1613
|
-
reference = outputs.reference_points if return_dict else outputs[-2]
|
|
1572
|
+
reference = outputs.reference_points
|
|
1614
1573
|
reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
|
|
1615
1574
|
|
|
1616
1575
|
hs = sequence_output
|
|
@@ -1624,7 +1583,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
|
|
|
1624
1583
|
outputs_class, outputs_coord = None, None
|
|
1625
1584
|
if self.config.auxiliary_loss:
|
|
1626
1585
|
outputs_coords = []
|
|
1627
|
-
intermediate = outputs.intermediate_hidden_states
|
|
1586
|
+
intermediate = outputs.intermediate_hidden_states
|
|
1628
1587
|
outputs_class = self.class_labels_classifier(intermediate)
|
|
1629
1588
|
for lvl in range(intermediate.shape[0]):
|
|
1630
1589
|
tmp = self.bbox_predictor(intermediate[lvl])
|
|
@@ -1636,13 +1595,6 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
|
|
|
1636
1595
|
logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
|
|
1637
1596
|
)
|
|
1638
1597
|
|
|
1639
|
-
if not return_dict:
|
|
1640
|
-
if auxiliary_outputs is not None:
|
|
1641
|
-
output = (logits, pred_boxes) + auxiliary_outputs + outputs
|
|
1642
|
-
else:
|
|
1643
|
-
output = (logits, pred_boxes) + outputs
|
|
1644
|
-
return ((loss, loss_dict) + output) if loss is not None else output
|
|
1645
|
-
|
|
1646
1598
|
return ConditionalDetrObjectDetectionOutput(
|
|
1647
1599
|
loss=loss,
|
|
1648
1600
|
loss_dict=loss_dict,
|
|
@@ -1658,14 +1610,38 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
|
|
|
1658
1610
|
encoder_attentions=outputs.encoder_attentions,
|
|
1659
1611
|
)
|
|
1660
1612
|
|
|
1613
|
+
# taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/conditional_detr.py
|
|
1614
|
+
def _set_aux_loss(self, outputs_class, outputs_coord):
|
|
1615
|
+
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
|
|
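`_set_aux_loss` pairs the class logits and box predictions of every intermediate decoder layer, dropping the final layer (whose outputs are returned as the main prediction). A tiny illustration with dummy tensors; the list comprehension mirrors the method body above:

```python
import torch

outputs_class = torch.rand(6, 2, 300, 91)  # (decoder_layers, batch, queries, labels)
outputs_coord = torch.rand(6, 2, 300, 4)
aux = [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
print(len(aux))                  # 5: one dict per intermediate decoder layer
print(aux[0]["logits"].shape)    # torch.Size([2, 300, 91])
```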
1616
|
+
|
|
1661
1617
|
|
|
1662
1618
|
@auto_docstring(
|
|
1663
1619
|
custom_intro="""
|
|
1664
|
-
|
|
1665
|
-
|
|
1620
|
+
CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
|
|
1621
|
+
such as COCO panoptic.
|
|
1666
1622
|
"""
|
|
1667
1623
|
)
|
|
1668
1624
|
class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
|
|
1625
|
+
_checkpoint_conversion_mapping = {
|
|
1626
|
+
"bbox_attention.q_linear": "bbox_attention.q_proj",
|
|
1627
|
+
"bbox_attention.k_linear": "bbox_attention.k_proj",
|
|
1628
|
+
# Mask head refactor
|
|
1629
|
+
"mask_head.lay1": "mask_head.conv1.conv",
|
|
1630
|
+
"mask_head.gn1": "mask_head.conv1.norm",
|
|
1631
|
+
"mask_head.lay2": "mask_head.conv2.conv",
|
|
1632
|
+
"mask_head.gn2": "mask_head.conv2.norm",
|
|
1633
|
+
"mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter",
|
|
1634
|
+
"mask_head.lay3": "mask_head.fpn_stages.0.refine.conv",
|
|
1635
|
+
"mask_head.gn3": "mask_head.fpn_stages.0.refine.norm",
|
|
1636
|
+
"mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter",
|
|
1637
|
+
"mask_head.lay4": "mask_head.fpn_stages.1.refine.conv",
|
|
1638
|
+
"mask_head.gn4": "mask_head.fpn_stages.1.refine.norm",
|
|
1639
|
+
"mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter",
|
|
1640
|
+
"mask_head.lay5": "mask_head.fpn_stages.2.refine.conv",
|
|
1641
|
+
"mask_head.gn5": "mask_head.fpn_stages.2.refine.norm",
|
|
1642
|
+
"mask_head.out_lay": "mask_head.output_conv",
|
|
1643
|
+
}
|
|
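`_checkpoint_conversion_mapping` renames old checkpoint keys (the numbered `lay*`/`gn*`/`adapter*` mask-head layers and the bbox-attention projections) to the refactored module paths. A minimal sketch of applying such a mapping to an old `state_dict` (illustrative only; it is an assumption that this mirrors how the library applies the mapping at load time):

```python
def convert_old_keys(state_dict: dict, mapping: dict) -> dict:
    # Replace each old substring with its new name; keys without a match pass through unchanged.
    converted = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in mapping.items():
            if old in new_key:
                new_key = new_key.replace(old, new)
        converted[new_key] = value
    return converted

old_state_dict = {"mask_head.lay1.weight": 0, "mask_head.gn1.bias": 1, "bbox_attention.q_linear.weight": 2}
mapping = {
    "bbox_attention.q_linear": "bbox_attention.q_proj",
    "mask_head.lay1": "mask_head.conv1.conv",
    "mask_head.gn1": "mask_head.conv1.norm",
}
print(convert_old_keys(old_state_dict, mapping))
# {'mask_head.conv1.conv.weight': 0, 'mask_head.conv1.norm.bias': 1, 'bbox_attention.q_proj.weight': 2}
```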
1644
|
+
|
|
1669
1645
|
def __init__(self, config: ConditionalDetrConfig):
|
|
1670
1646
|
super().__init__(config)
|
|
1671
1647
|
|
|
@@ -1674,43 +1650,44 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
|
|
|
1674
1650
|
|
|
1675
1651
|
# segmentation head
|
|
1676
1652
|
hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
|
|
1677
|
-
intermediate_channel_sizes = self.conditional_detr.model.backbone.
|
|
1653
|
+
intermediate_channel_sizes = self.conditional_detr.model.backbone.intermediate_channel_sizes
|
|
1678
1654
|
|
|
1679
1655
|
self.mask_head = ConditionalDetrMaskHeadSmallConv(
|
|
1680
|
-
hidden_size + number_of_heads,
|
|
1681
|
-
|
|
1682
|
-
|
|
1683
|
-
|
|
1684
|
-
hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
|
|
1656
|
+
input_channels=hidden_size + number_of_heads,
|
|
1657
|
+
fpn_channels=intermediate_channel_sizes[::-1][-3:],
|
|
1658
|
+
hidden_size=hidden_size,
|
|
1659
|
+
activation_function=config.activation_function,
|
|
1685
1660
|
)
|
|
1686
1661
|
|
|
1662
|
+
self.bbox_attention = ConditionalDetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0)
|
|
1687
1663
|
# Initialize weights and apply final processing
|
|
1688
1664
|
self.post_init()
|
|
1689
1665
|
|
|
1690
1666
|
@auto_docstring
|
|
1667
|
+
@can_return_tuple
|
|
1691
1668
|
def forward(
|
|
1692
1669
|
self,
|
|
1693
1670
|
pixel_values: torch.FloatTensor,
|
|
1694
|
-
pixel_mask:
|
|
1695
|
-
decoder_attention_mask:
|
|
1696
|
-
encoder_outputs:
|
|
1697
|
-
inputs_embeds:
|
|
1698
|
-
decoder_inputs_embeds:
|
|
1699
|
-
labels:
|
|
1700
|
-
|
|
1701
|
-
|
|
1702
|
-
return_dict: Optional[bool] = None,
|
|
1703
|
-
**kwargs,
|
|
1704
|
-
) -> Union[tuple[torch.FloatTensor], ConditionalDetrSegmentationOutput]:
|
|
1671
|
+
pixel_mask: torch.LongTensor | None = None,
|
|
1672
|
+
decoder_attention_mask: torch.FloatTensor | None = None,
|
|
1673
|
+
encoder_outputs: torch.FloatTensor | None = None,
|
|
1674
|
+
inputs_embeds: torch.FloatTensor | None = None,
|
|
1675
|
+
decoder_inputs_embeds: torch.FloatTensor | None = None,
|
|
1676
|
+
labels: list[dict] | None = None,
|
|
1677
|
+
**kwargs: Unpack[TransformersKwargs],
|
|
1678
|
+
) -> tuple[torch.FloatTensor] | ConditionalDetrSegmentationOutput:
|
|
1705
1679
|
r"""
|
|
1706
1680
|
decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
|
|
1707
|
-
|
|
1681
|
+
Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
|
|
1682
|
+
|
|
1683
|
+
- 1 for queries that are **not masked**,
|
|
1684
|
+
- 0 for queries that are **masked**.
|
|
1708
1685
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
1709
|
-
|
|
1710
|
-
|
|
1686
|
+
Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires
|
|
1687
|
+
multi-scale features from the backbone that are not available when bypassing it with inputs_embeds.
|
|
1711
1688
|
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
|
|
1712
1689
|
Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
|
|
1713
|
-
embedded representation.
|
|
1690
|
+
embedded representation. Useful for tasks that require custom query initialization.
|
|
1714
1691
|
labels (`list[Dict]` of len `(batch_size,)`, *optional*):
|
|
1715
1692
|
Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
|
|
1716
1693
|
dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
|
|
@@ -1723,26 +1700,21 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
|
|
|
1723
1700
|
|
|
1724
1701
|
```python
|
|
1725
1702
|
>>> import io
|
|
1726
|
-
>>> import
|
|
1703
|
+
>>> import httpx
|
|
1704
|
+
>>> from io import BytesIO
|
|
1727
1705
|
>>> from PIL import Image
|
|
1728
1706
|
>>> import torch
|
|
1729
1707
|
>>> import numpy
|
|
1730
1708
|
|
|
1731
|
-
>>> from transformers import
|
|
1732
|
-
... AutoImageProcessor,
|
|
1733
|
-
... ConditionalDetrConfig,
|
|
1734
|
-
... ConditionalDetrForSegmentation,
|
|
1735
|
-
... )
|
|
1709
|
+
>>> from transformers import AutoImageProcessor, ConditionalDetrForSegmentation
|
|
1736
1710
|
>>> from transformers.image_transforms import rgb_to_id
|
|
1737
1711
|
|
|
1738
1712
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
|
1739
|
-
>>>
|
|
1740
|
-
|
|
1741
|
-
>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
|
|
1713
|
+
>>> with httpx.stream("GET", url) as response:
|
|
1714
|
+
... image = Image.open(BytesIO(response.read()))
|
|
1742
1715
|
|
|
1743
|
-
>>>
|
|
1744
|
-
>>>
|
|
1745
|
-
>>> model = ConditionalDetrForSegmentation(config)
|
|
1716
|
+
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/conditional_detr-resnet-50-panoptic")
|
|
1717
|
+
>>> model = ConditionalDetrForSegmentation.from_pretrained("facebook/conditional_detr-resnet-50-panoptic")
|
|
1746
1718
|
|
|
1747
1719
|
>>> # prepare image for the model
|
|
1748
1720
|
>>> inputs = image_processor(images=image, return_tensors="pt")
|
|
@@ -1753,89 +1725,88 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
|
|
|
1753
1725
|
>>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
|
|
1754
1726
|
>>> # Segmentation results are returned as a list of dictionaries
|
|
1755
1727
|
>>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])
|
|
1728
|
+
|
|
1756
1729
|
>>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
|
|
1757
1730
|
>>> panoptic_seg = result[0]["segmentation"]
|
|
1731
|
+
>>> panoptic_seg.shape
|
|
1732
|
+
torch.Size([300, 500])
|
|
1758
1733
|
>>> # Get prediction score and segment_id to class_id mapping of each segment
|
|
1759
1734
|
>>> panoptic_segments_info = result[0]["segments_info"]
|
|
1735
|
+
>>> len(panoptic_segments_info)
|
|
1736
|
+
5
|
|
1760
1737
|
```"""
|
|
1761
1738
|
|
|
1762
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1763
|
-
|
|
1764
1739
|
batch_size, num_channels, height, width = pixel_values.shape
|
|
1765
1740
|
device = pixel_values.device
|
|
1766
1741
|
|
|
1767
1742
|
if pixel_mask is None:
|
|
1768
1743
|
pixel_mask = torch.ones((batch_size, height, width), device=device)
|
|
1769
1744
|
|
|
1770
|
-
|
|
1771
|
-
|
|
1745
|
+
vision_features = self.conditional_detr.model.backbone(pixel_values, pixel_mask)
|
|
1746
|
+
feature_map, mask = vision_features[-1]
|
|
1772
1747
|
|
|
1773
|
-
#
|
|
1774
|
-
feature_map, mask = features[-1]
|
|
1775
|
-
batch_size, num_channels, height, width = feature_map.shape
|
|
1748
|
+
# Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
|
|
1776
1749
|
projected_feature_map = self.conditional_detr.model.input_projection(feature_map)
|
|
1777
|
-
|
|
1778
|
-
# Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
|
|
1779
|
-
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
|
|
1780
1750
|
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
|
1781
|
-
|
|
1782
|
-
|
|
1751
|
+
spatial_position_embeddings = self.conditional_detr.model.position_embedding(
|
|
1752
|
+
shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
|
|
1753
|
+
)
|
|
1783
1754
|
flattened_mask = mask.flatten(1)
|
|
1784
1755
|
|
|
1785
|
-
# Fourth, sent flattened_features + flattened_mask + object_queries through encoder
|
|
1786
|
-
# flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
|
|
1787
|
-
# flattened_mask is a Tensor of shape (batch_size, height*width)
|
|
1788
1756
|
if encoder_outputs is None:
|
|
1789
1757
|
encoder_outputs = self.conditional_detr.model.encoder(
|
|
1790
1758
|
inputs_embeds=flattened_features,
|
|
1791
1759
|
attention_mask=flattened_mask,
|
|
1792
|
-
|
|
1793
|
-
|
|
1794
|
-
output_hidden_states=output_hidden_states,
|
|
1795
|
-
return_dict=return_dict,
|
|
1796
|
-
)
|
|
1797
|
-
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
|
|
1798
|
-
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
|
1799
|
-
encoder_outputs = BaseModelOutput(
|
|
1800
|
-
last_hidden_state=encoder_outputs[0],
|
|
1801
|
-
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
|
1802
|
-
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
|
1760
|
+
spatial_position_embeddings=spatial_position_embeddings,
|
|
1761
|
+
**kwargs,
|
|
1803
1762
|
)
|
|
1804
1763
|
|
|
1805
|
-
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
queries
|
|
1764
|
+
object_queries_position_embeddings = self.conditional_detr.model.query_position_embeddings.weight.unsqueeze(
|
|
1765
|
+
0
|
|
1766
|
+
).repeat(batch_size, 1, 1)
|
|
1767
|
+
|
|
1768
|
+
# Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
|
|
1769
|
+
if decoder_inputs_embeds is not None:
|
|
1770
|
+
queries = decoder_inputs_embeds
|
|
1771
|
+
else:
|
|
1772
|
+
queries = torch.zeros_like(object_queries_position_embeddings)
|
|
1810
1773
|
|
|
1811
|
-
# decoder outputs consists of (dec_features, dec_hidden, dec_attn)
|
|
1812
1774
|
decoder_outputs = self.conditional_detr.model.decoder(
|
|
1813
1775
|
inputs_embeds=queries,
|
|
1814
|
-
attention_mask=
|
|
1815
|
-
|
|
1816
|
-
|
|
1817
|
-
encoder_hidden_states=encoder_outputs
|
|
1776
|
+
attention_mask=decoder_attention_mask,
|
|
1777
|
+
spatial_position_embeddings=spatial_position_embeddings,
|
|
1778
|
+
object_queries_position_embeddings=object_queries_position_embeddings,
|
|
1779
|
+
encoder_hidden_states=encoder_outputs.last_hidden_state,
|
|
1818
1780
|
encoder_attention_mask=flattened_mask,
|
|
1819
|
-
|
|
1820
|
-
output_hidden_states=output_hidden_states,
|
|
1821
|
-
return_dict=return_dict,
|
|
1781
|
+
**kwargs,
|
|
1822
1782
|
)
|
|
1823
1783
|
|
|
1824
1784
|
sequence_output = decoder_outputs[0]
|
|
1825
1785
|
|
|
1826
|
-
# Sixth, compute logits, pred_boxes and pred_masks
|
|
1827
1786
|
logits = self.conditional_detr.class_labels_classifier(sequence_output)
|
|
1828
1787
|
pred_boxes = self.conditional_detr.bbox_predictor(sequence_output).sigmoid()
|
|
1829
1788
|
|
|
1830
|
-
|
|
1831
|
-
|
|
1789
|
+
height, width = feature_map.shape[-2:]
|
|
1790
|
+
memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view(
|
|
1791
|
+
batch_size, self.config.d_model, height, width
|
|
1792
|
+
)
|
|
1793
|
+
attention_mask = flattened_mask.view(batch_size, height, width)
|
|
1832
1794
|
|
|
1833
|
-
|
|
1834
|
-
|
|
1835
|
-
|
|
1836
|
-
|
|
1795
|
+
if attention_mask is not None:
|
|
1796
|
+
min_dtype = torch.finfo(memory.dtype).min
|
|
1797
|
+
attention_mask = torch.where(
|
|
1798
|
+
attention_mask.unsqueeze(1).unsqueeze(1),
|
|
1799
|
+
torch.tensor(0.0, device=memory.device, dtype=memory.dtype),
|
|
1800
|
+
min_dtype,
|
|
1801
|
+
)
|
|
1837
1802
|
|
|
1838
|
-
|
|
1803
|
+
bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=attention_mask)
|
|
1804
|
+
|
|
1805
|
+
seg_masks = self.mask_head(
|
|
1806
|
+
features=projected_feature_map,
|
|
1807
|
+
attention_masks=bbox_mask,
|
|
1808
|
+
fpn_features=[vision_features[2][0], vision_features[1][0], vision_features[0][0]],
|
|
1809
|
+
)
|
|
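Shape flow through the segmentation branch above, as a small sketch with illustrative sizes (`25x34` is an arbitrary final feature-map resolution; the `bbox_mask` shape follows the removed `ConditionalDetrMHAttentionMap` further down, whose attention weights are `(batch, queries, heads, height, width)`):

```python
import torch

# Illustrative sizes; the real height/width come from the backbone's final feature map.
batch_size, d_model, height, width, num_queries, num_heads = 1, 256, 25, 34, 300, 8
encoder_last_hidden_state = torch.rand(batch_size, height * width, d_model)

# Undo the earlier flatten(2).permute(0, 2, 1): (batch, H*W, d_model) -> (batch, d_model, H, W)
memory = encoder_last_hidden_state.permute(0, 2, 1).view(batch_size, d_model, height, width)
print(memory.shape)  # torch.Size([1, 256, 25, 34])

# bbox_attention yields one attention map per query and head over this grid; the mask
# head consumes it together with the projected feature map and the FPN features listed above.
bbox_mask = torch.rand(batch_size, num_queries, num_heads, height, width)
print(bbox_mask.flatten(0, 1).shape)  # torch.Size([300, 8, 25, 34]) -- per-query maps stacked on the batch axis
```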
1839
1810
|
|
|
1840
1811
|
pred_masks = seg_masks.view(
|
|
1841
1812
|
batch_size, self.conditional_detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]
|
|
@@ -1845,20 +1816,13 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
|
|
|
1845
1816
|
if labels is not None:
|
|
1846
1817
|
outputs_class, outputs_coord = None, None
|
|
1847
1818
|
if self.config.auxiliary_loss:
|
|
1848
|
-
intermediate = decoder_outputs.intermediate_hidden_states
|
|
1819
|
+
intermediate = decoder_outputs.intermediate_hidden_states
|
|
1849
1820
|
outputs_class = self.conditional_detr.class_labels_classifier(intermediate)
|
|
1850
1821
|
outputs_coord = self.conditional_detr.bbox_predictor(intermediate).sigmoid()
|
|
1851
1822
|
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
|
1852
|
-
logits, labels,
|
|
1823
|
+
logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
|
|
1853
1824
|
)
|
|
1854
1825
|
|
|
1855
|
-
if not return_dict:
|
|
1856
|
-
if auxiliary_outputs is not None:
|
|
1857
|
-
output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
|
|
1858
|
-
else:
|
|
1859
|
-
output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
|
|
1860
|
-
return ((loss, loss_dict) + output) if loss is not None else output
|
|
1861
|
-
|
|
1862
1826
|
return ConditionalDetrSegmentationOutput(
|
|
1863
1827
|
loss=loss,
|
|
1864
1828
|
loss_dict=loss_dict,
|
|
@@ -1876,120 +1840,6 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
|
|
|
1876
1840
|
)
|
|
1877
1841
|
|
|
1878
1842
|
|
|
1879
|
-
def _expand(tensor, length: int):
|
|
1880
|
-
return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
|
|
1881
|
-
|
|
1882
|
-
|
|
1883
|
-
# Copied from transformers.models.detr.modeling_detr.DetrMaskHeadSmallConv with Detr->ConditionalDetr
|
|
1884
|
-
class ConditionalDetrMaskHeadSmallConv(nn.Module):
|
|
1885
|
-
"""
|
|
1886
|
-
Simple convolutional head, using group norm. Upsampling is done using a FPN approach
|
|
1887
|
-
"""
|
|
1888
|
-
|
|
1889
|
-
def __init__(self, dim, fpn_dims, context_dim):
|
|
1890
|
-
super().__init__()
|
|
1891
|
-
|
|
1892
|
-
if dim % 8 != 0:
|
|
1893
|
-
raise ValueError(
|
|
1894
|
-
"The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
|
|
1895
|
-
" GroupNorm is set to 8"
|
|
1896
|
-
)
|
|
1897
|
-
|
|
1898
|
-
inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
|
|
1899
|
-
|
|
1900
|
-
self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
|
|
1901
|
-
self.gn1 = nn.GroupNorm(8, dim)
|
|
1902
|
-
self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
|
|
1903
|
-
self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
|
|
1904
|
-
self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
|
|
1905
|
-
self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
|
|
1906
|
-
self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
|
|
1907
|
-
self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
|
|
1908
|
-
self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
|
|
1909
|
-
self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
|
|
1910
|
-
self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
|
|
1911
|
-
|
|
1912
|
-
self.dim = dim
|
|
1913
|
-
|
|
1914
|
-
self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
|
|
1915
|
-
self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
|
|
1916
|
-
self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
|
|
1917
|
-
|
|
1918
|
-
for m in self.modules():
|
|
1919
|
-
if isinstance(m, nn.Conv2d):
|
|
1920
|
-
init.kaiming_uniform_(m.weight, a=1)
|
|
1921
|
-
init.constant_(m.bias, 0)
|
|
1922
|
-
|
|
1923
|
-
def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]):
|
|
1924
|
-
# here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
|
|
1925
|
-
# the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
|
|
1926
|
-
# We expand the projected feature map to match the number of heads.
|
|
1927
|
-
x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
|
|
1928
|
-
|
|
1929
|
-
x = self.lay1(x)
|
|
1930
|
-
x = self.gn1(x)
|
|
1931
|
-
x = nn.functional.relu(x)
|
|
1932
|
-
x = self.lay2(x)
|
|
1933
|
-
x = self.gn2(x)
|
|
1934
|
-
x = nn.functional.relu(x)
|
|
1935
|
-
|
|
1936
|
-
cur_fpn = self.adapter1(fpns[0])
|
|
1937
|
-
if cur_fpn.size(0) != x.size(0):
|
|
1938
|
-
cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
|
|
1939
|
-
x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
|
|
1940
|
-
x = self.lay3(x)
|
|
1941
|
-
x = self.gn3(x)
|
|
1942
|
-
x = nn.functional.relu(x)
|
|
1943
|
-
|
|
1944
|
-
cur_fpn = self.adapter2(fpns[1])
|
|
1945
|
-
if cur_fpn.size(0) != x.size(0):
|
|
1946
|
-
cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
|
|
1947
|
-
x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
|
|
1948
|
-
x = self.lay4(x)
|
|
1949
|
-
x = self.gn4(x)
|
|
1950
|
-
x = nn.functional.relu(x)
|
|
1951
|
-
|
|
1952
|
-
cur_fpn = self.adapter3(fpns[2])
|
|
1953
|
-
if cur_fpn.size(0) != x.size(0):
|
|
1954
|
-
cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
|
|
1955
|
-
x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
|
|
1956
|
-
x = self.lay5(x)
|
|
1957
|
-
x = self.gn5(x)
|
|
1958
|
-
x = nn.functional.relu(x)
|
|
1959
|
-
|
|
1960
|
-
x = self.out_lay(x)
|
|
1961
|
-
return x
|
|
1962
|
-
|
|
1963
|
-
|
|
1964
|
-
# Copied from transformers.models.detr.modeling_detr.DetrMHAttentionMap with Detr->ConditionalDetr
|
|
1965
|
-
class ConditionalDetrMHAttentionMap(nn.Module):
|
|
1966
|
-
"""This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
|
|
1967
|
-
|
|
1968
|
-
def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
|
|
1969
|
-
super().__init__()
|
|
1970
|
-
self.num_heads = num_heads
|
|
1971
|
-
self.hidden_dim = hidden_dim
|
|
1972
|
-
self.dropout = nn.Dropout(dropout)
|
|
1973
|
-
|
|
1974
|
-
self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
|
|
1975
|
-
self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
|
|
1976
|
-
|
|
1977
|
-
self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
|
|
1978
|
-
|
|
1979
|
-
def forward(self, q, k, mask: Optional[Tensor] = None):
|
|
1980
|
-
q = self.q_linear(q)
|
|
1981
|
-
k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
|
|
1982
|
-
queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
|
|
1983
|
-
keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
|
|
1984
|
-
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
|
|
1985
|
-
|
|
1986
|
-
if mask is not None:
|
|
1987
|
-
weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
|
|
1988
|
-
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
|
|
1989
|
-
weights = self.dropout(weights)
|
|
1990
|
-
return weights
|
|
1991
|
-
|
|
1992
|
-
|
|
1993
1843
|
__all__ = [
|
|
1994
1844
|
"ConditionalDetrForObjectDetection",
|
|
1995
1845
|
"ConditionalDetrForSegmentation",
|