transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
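For context, here is a minimal sketch (not the registry's own tooling) of how a per-file "+added -removed" summary like the one below could be reproduced locally, assuming both wheels have already been downloaded; the wheel filenames used are assumptions based on standard wheel naming:

```python
# Sketch: derive a per-file +added/-removed summary from two wheel archives.
# Assumes both wheels are present in the working directory.
import difflib
import zipfile

OLD = "transformers-5.0.0rc2-py3-none-any.whl"
NEW = "transformers-5.1.0-py3-none-any.whl"

def read_members(path):
    # A wheel is a zip archive; map each .py member to its decoded lines.
    with zipfile.ZipFile(path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace").splitlines()
            for name in zf.namelist()
            if name.endswith(".py")
        }

old, new = read_members(OLD), read_members(NEW)
for name in sorted(set(old) | set(new)):
    added = removed = 0
    for line in difflib.unified_diff(old.get(name, []), new.get(name, []), lineterm=""):
        if line.startswith("+") and not line.startswith("+++"):
            added += 1
        elif line.startswith("-") and not line.startswith("---"):
            removed += 1
    if added or removed:
        print(f"{name} +{added} -{removed}")
```

The per-file summary for 5.0.0rc2 → 5.1.0 follows.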
- transformers/__init__.py +11 -37
- transformers/activations.py +2 -2
- transformers/audio_utils.py +32 -32
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +26 -126
- transformers/cli/chat.py +3 -3
- transformers/cli/serve.py +13 -10
- transformers/cli/transformers.py +2 -1
- transformers/configuration_utils.py +22 -92
- transformers/conversion_mapping.py +150 -26
- transformers/convert_slow_tokenizer.py +9 -12
- transformers/core_model_loading.py +217 -129
- transformers/data/processors/glue.py +0 -1
- transformers/data/processors/utils.py +0 -1
- transformers/data/processors/xnli.py +0 -1
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +10 -11
- transformers/distributed/configuration_utils.py +1 -2
- transformers/dynamic_module_utils.py +23 -23
- transformers/feature_extraction_sequence_utils.py +19 -23
- transformers/feature_extraction_utils.py +14 -14
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +2 -4
- transformers/generation/configuration_utils.py +54 -39
- transformers/generation/continuous_batching/__init__.py +0 -1
- transformers/generation/continuous_batching/cache.py +74 -44
- transformers/generation/continuous_batching/cache_manager.py +28 -28
- transformers/generation/continuous_batching/continuous_api.py +133 -414
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +77 -19
- transformers/generation/continuous_batching/scheduler.py +154 -104
- transformers/generation/logits_process.py +10 -133
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/streamers.py +0 -1
- transformers/generation/utils.py +91 -121
- transformers/generation/watermarking.py +2 -3
- transformers/hf_argparser.py +9 -13
- transformers/hyperparameter_search.py +1 -2
- transformers/image_processing_base.py +9 -9
- transformers/image_processing_utils.py +11 -15
- transformers/image_processing_utils_fast.py +70 -71
- transformers/image_transforms.py +73 -42
- transformers/image_utils.py +30 -37
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/awq.py +1 -3
- transformers/integrations/deepspeed.py +146 -4
- transformers/integrations/eetq.py +0 -1
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/fbgemm_fp8.py +1 -2
- transformers/integrations/finegrained_fp8.py +149 -13
- transformers/integrations/flash_attention.py +3 -8
- transformers/integrations/flex_attention.py +1 -1
- transformers/integrations/fp_quant.py +4 -6
- transformers/integrations/ggml.py +0 -1
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/integration_utils.py +2 -3
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +52 -40
- transformers/integrations/peft.py +488 -176
- transformers/integrations/quark.py +2 -4
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/integrations/torchao.py +4 -6
- transformers/loss/loss_lw_detr.py +356 -0
- transformers/loss/loss_utils.py +2 -0
- transformers/masking_utils.py +199 -59
- transformers/model_debugging_utils.py +4 -5
- transformers/modelcard.py +14 -192
- transformers/modeling_attn_mask_utils.py +19 -19
- transformers/modeling_flash_attention_utils.py +28 -29
- transformers/modeling_gguf_pytorch_utils.py +5 -5
- transformers/modeling_layers.py +21 -22
- transformers/modeling_outputs.py +242 -253
- transformers/modeling_rope_utils.py +32 -32
- transformers/modeling_utils.py +416 -438
- transformers/models/__init__.py +10 -0
- transformers/models/afmoe/configuration_afmoe.py +40 -33
- transformers/models/afmoe/modeling_afmoe.py +38 -41
- transformers/models/afmoe/modular_afmoe.py +23 -25
- transformers/models/aimv2/configuration_aimv2.py +2 -10
- transformers/models/aimv2/modeling_aimv2.py +46 -45
- transformers/models/aimv2/modular_aimv2.py +13 -19
- transformers/models/albert/configuration_albert.py +8 -2
- transformers/models/albert/modeling_albert.py +70 -72
- transformers/models/albert/tokenization_albert.py +1 -4
- transformers/models/align/configuration_align.py +8 -6
- transformers/models/align/modeling_align.py +83 -86
- transformers/models/align/processing_align.py +2 -30
- transformers/models/altclip/configuration_altclip.py +4 -7
- transformers/models/altclip/modeling_altclip.py +106 -103
- transformers/models/altclip/processing_altclip.py +2 -15
- transformers/models/apertus/__init__.py +0 -1
- transformers/models/apertus/configuration_apertus.py +23 -28
- transformers/models/apertus/modeling_apertus.py +35 -38
- transformers/models/apertus/modular_apertus.py +36 -40
- transformers/models/arcee/configuration_arcee.py +25 -30
- transformers/models/arcee/modeling_arcee.py +35 -38
- transformers/models/arcee/modular_arcee.py +20 -23
- transformers/models/aria/configuration_aria.py +31 -44
- transformers/models/aria/image_processing_aria.py +25 -27
- transformers/models/aria/modeling_aria.py +102 -102
- transformers/models/aria/modular_aria.py +111 -124
- transformers/models/aria/processing_aria.py +28 -35
- transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
- transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
- transformers/models/audioflamingo3/__init__.py +0 -1
- transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
- transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
- transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
- transformers/models/auto/auto_factory.py +12 -11
- transformers/models/auto/configuration_auto.py +48 -5
- transformers/models/auto/feature_extraction_auto.py +5 -7
- transformers/models/auto/image_processing_auto.py +30 -39
- transformers/models/auto/modeling_auto.py +33 -199
- transformers/models/auto/processing_auto.py +11 -19
- transformers/models/auto/tokenization_auto.py +38 -37
- transformers/models/auto/video_processing_auto.py +7 -8
- transformers/models/autoformer/configuration_autoformer.py +4 -7
- transformers/models/autoformer/modeling_autoformer.py +100 -101
- transformers/models/aya_vision/configuration_aya_vision.py +4 -1
- transformers/models/aya_vision/modeling_aya_vision.py +64 -99
- transformers/models/aya_vision/modular_aya_vision.py +46 -74
- transformers/models/aya_vision/processing_aya_vision.py +25 -53
- transformers/models/bamba/configuration_bamba.py +46 -39
- transformers/models/bamba/modeling_bamba.py +83 -119
- transformers/models/bamba/modular_bamba.py +70 -109
- transformers/models/bark/configuration_bark.py +6 -8
- transformers/models/bark/generation_configuration_bark.py +3 -5
- transformers/models/bark/modeling_bark.py +64 -65
- transformers/models/bark/processing_bark.py +19 -41
- transformers/models/bart/configuration_bart.py +9 -5
- transformers/models/bart/modeling_bart.py +124 -129
- transformers/models/barthez/tokenization_barthez.py +1 -4
- transformers/models/bartpho/tokenization_bartpho.py +6 -7
- transformers/models/beit/configuration_beit.py +2 -15
- transformers/models/beit/image_processing_beit.py +53 -56
- transformers/models/beit/image_processing_beit_fast.py +11 -12
- transformers/models/beit/modeling_beit.py +65 -62
- transformers/models/bert/configuration_bert.py +12 -2
- transformers/models/bert/modeling_bert.py +117 -152
- transformers/models/bert/tokenization_bert.py +2 -4
- transformers/models/bert/tokenization_bert_legacy.py +3 -5
- transformers/models/bert_generation/configuration_bert_generation.py +17 -2
- transformers/models/bert_generation/modeling_bert_generation.py +53 -55
- transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
- transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
- transformers/models/bertweet/tokenization_bertweet.py +1 -3
- transformers/models/big_bird/configuration_big_bird.py +12 -9
- transformers/models/big_bird/modeling_big_bird.py +107 -124
- transformers/models/big_bird/tokenization_big_bird.py +1 -4
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
- transformers/models/biogpt/configuration_biogpt.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +73 -79
- transformers/models/biogpt/modular_biogpt.py +60 -66
- transformers/models/biogpt/tokenization_biogpt.py +3 -5
- transformers/models/bit/configuration_bit.py +2 -5
- transformers/models/bit/image_processing_bit.py +21 -24
- transformers/models/bit/image_processing_bit_fast.py +0 -1
- transformers/models/bit/modeling_bit.py +15 -16
- transformers/models/bitnet/configuration_bitnet.py +23 -28
- transformers/models/bitnet/modeling_bitnet.py +34 -38
- transformers/models/bitnet/modular_bitnet.py +7 -10
- transformers/models/blenderbot/configuration_blenderbot.py +8 -5
- transformers/models/blenderbot/modeling_blenderbot.py +68 -99
- transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
- transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
- transformers/models/blip/configuration_blip.py +9 -10
- transformers/models/blip/image_processing_blip.py +17 -20
- transformers/models/blip/image_processing_blip_fast.py +0 -1
- transformers/models/blip/modeling_blip.py +115 -108
- transformers/models/blip/modeling_blip_text.py +63 -65
- transformers/models/blip/processing_blip.py +5 -36
- transformers/models/blip_2/configuration_blip_2.py +2 -2
- transformers/models/blip_2/modeling_blip_2.py +145 -121
- transformers/models/blip_2/processing_blip_2.py +8 -38
- transformers/models/bloom/configuration_bloom.py +5 -2
- transformers/models/bloom/modeling_bloom.py +60 -60
- transformers/models/blt/configuration_blt.py +94 -86
- transformers/models/blt/modeling_blt.py +93 -90
- transformers/models/blt/modular_blt.py +127 -69
- transformers/models/bridgetower/configuration_bridgetower.py +7 -2
- transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
- transformers/models/bridgetower/modeling_bridgetower.py +136 -124
- transformers/models/bridgetower/processing_bridgetower.py +2 -16
- transformers/models/bros/configuration_bros.py +24 -18
- transformers/models/bros/modeling_bros.py +78 -80
- transformers/models/bros/processing_bros.py +2 -12
- transformers/models/byt5/tokenization_byt5.py +4 -6
- transformers/models/camembert/configuration_camembert.py +8 -2
- transformers/models/camembert/modeling_camembert.py +97 -99
- transformers/models/camembert/modular_camembert.py +51 -54
- transformers/models/camembert/tokenization_camembert.py +1 -4
- transformers/models/canine/configuration_canine.py +4 -2
- transformers/models/canine/modeling_canine.py +73 -75
- transformers/models/canine/tokenization_canine.py +0 -1
- transformers/models/chameleon/configuration_chameleon.py +29 -34
- transformers/models/chameleon/image_processing_chameleon.py +21 -24
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
- transformers/models/chameleon/modeling_chameleon.py +135 -92
- transformers/models/chameleon/processing_chameleon.py +16 -41
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
- transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
- transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
- transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
- transformers/models/clap/configuration_clap.py +4 -9
- transformers/models/clap/feature_extraction_clap.py +9 -10
- transformers/models/clap/modeling_clap.py +109 -111
- transformers/models/clap/processing_clap.py +2 -15
- transformers/models/clip/configuration_clip.py +4 -2
- transformers/models/clip/image_processing_clip.py +21 -24
- transformers/models/clip/image_processing_clip_fast.py +9 -1
- transformers/models/clip/modeling_clip.py +70 -68
- transformers/models/clip/processing_clip.py +2 -14
- transformers/models/clip/tokenization_clip.py +2 -5
- transformers/models/clipseg/configuration_clipseg.py +4 -2
- transformers/models/clipseg/modeling_clipseg.py +113 -112
- transformers/models/clipseg/processing_clipseg.py +19 -42
- transformers/models/clvp/configuration_clvp.py +15 -5
- transformers/models/clvp/feature_extraction_clvp.py +7 -10
- transformers/models/clvp/modeling_clvp.py +138 -145
- transformers/models/clvp/number_normalizer.py +1 -2
- transformers/models/clvp/processing_clvp.py +3 -20
- transformers/models/clvp/tokenization_clvp.py +0 -1
- transformers/models/code_llama/tokenization_code_llama.py +3 -6
- transformers/models/codegen/configuration_codegen.py +4 -4
- transformers/models/codegen/modeling_codegen.py +50 -49
- transformers/models/codegen/tokenization_codegen.py +5 -6
- transformers/models/cohere/configuration_cohere.py +25 -30
- transformers/models/cohere/modeling_cohere.py +39 -42
- transformers/models/cohere/modular_cohere.py +27 -31
- transformers/models/cohere/tokenization_cohere.py +5 -6
- transformers/models/cohere2/configuration_cohere2.py +27 -32
- transformers/models/cohere2/modeling_cohere2.py +38 -41
- transformers/models/cohere2/modular_cohere2.py +48 -52
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
- transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
- transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
- transformers/models/colpali/configuration_colpali.py +0 -1
- transformers/models/colpali/modeling_colpali.py +14 -16
- transformers/models/colpali/modular_colpali.py +11 -51
- transformers/models/colpali/processing_colpali.py +14 -52
- transformers/models/colqwen2/modeling_colqwen2.py +27 -28
- transformers/models/colqwen2/modular_colqwen2.py +36 -74
- transformers/models/colqwen2/processing_colqwen2.py +16 -52
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
- transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
- transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
- transformers/models/convbert/configuration_convbert.py +11 -8
- transformers/models/convbert/modeling_convbert.py +85 -87
- transformers/models/convbert/tokenization_convbert.py +0 -1
- transformers/models/convnext/configuration_convnext.py +2 -5
- transformers/models/convnext/image_processing_convnext.py +18 -21
- transformers/models/convnext/image_processing_convnext_fast.py +7 -8
- transformers/models/convnext/modeling_convnext.py +12 -14
- transformers/models/convnextv2/configuration_convnextv2.py +2 -5
- transformers/models/convnextv2/modeling_convnextv2.py +12 -14
- transformers/models/cpm/tokenization_cpm.py +6 -7
- transformers/models/cpm/tokenization_cpm_fast.py +3 -5
- transformers/models/cpmant/configuration_cpmant.py +4 -1
- transformers/models/cpmant/modeling_cpmant.py +38 -40
- transformers/models/cpmant/tokenization_cpmant.py +1 -3
- transformers/models/csm/configuration_csm.py +58 -66
- transformers/models/csm/generation_csm.py +13 -14
- transformers/models/csm/modeling_csm.py +81 -84
- transformers/models/csm/modular_csm.py +56 -58
- transformers/models/csm/processing_csm.py +25 -68
- transformers/models/ctrl/configuration_ctrl.py +16 -1
- transformers/models/ctrl/modeling_ctrl.py +51 -66
- transformers/models/ctrl/tokenization_ctrl.py +0 -1
- transformers/models/cvt/configuration_cvt.py +0 -1
- transformers/models/cvt/modeling_cvt.py +13 -15
- transformers/models/cwm/__init__.py +0 -1
- transformers/models/cwm/configuration_cwm.py +8 -12
- transformers/models/cwm/modeling_cwm.py +36 -38
- transformers/models/cwm/modular_cwm.py +10 -12
- transformers/models/d_fine/configuration_d_fine.py +10 -57
- transformers/models/d_fine/modeling_d_fine.py +786 -927
- transformers/models/d_fine/modular_d_fine.py +339 -417
- transformers/models/dab_detr/configuration_dab_detr.py +22 -49
- transformers/models/dab_detr/modeling_dab_detr.py +79 -77
- transformers/models/dac/configuration_dac.py +0 -1
- transformers/models/dac/feature_extraction_dac.py +6 -9
- transformers/models/dac/modeling_dac.py +22 -24
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
- transformers/models/data2vec/configuration_data2vec_text.py +11 -3
- transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
- transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
- transformers/models/data2vec/modeling_data2vec_text.py +97 -99
- transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
- transformers/models/data2vec/modular_data2vec_audio.py +6 -1
- transformers/models/data2vec/modular_data2vec_text.py +51 -54
- transformers/models/dbrx/configuration_dbrx.py +29 -22
- transformers/models/dbrx/modeling_dbrx.py +45 -48
- transformers/models/dbrx/modular_dbrx.py +37 -39
- transformers/models/deberta/configuration_deberta.py +6 -1
- transformers/models/deberta/modeling_deberta.py +57 -60
- transformers/models/deberta/tokenization_deberta.py +2 -5
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
- transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
- transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
- transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
- transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
- transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
- transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
- transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
- transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
- transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
- transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
- transformers/models/deit/configuration_deit.py +0 -1
- transformers/models/deit/image_processing_deit.py +18 -21
- transformers/models/deit/image_processing_deit_fast.py +0 -1
- transformers/models/deit/modeling_deit.py +27 -25
- transformers/models/depth_anything/configuration_depth_anything.py +12 -43
- transformers/models/depth_anything/modeling_depth_anything.py +10 -11
- transformers/models/depth_pro/configuration_depth_pro.py +0 -1
- transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
- transformers/models/depth_pro/modeling_depth_pro.py +29 -27
- transformers/models/detr/configuration_detr.py +18 -50
- transformers/models/detr/image_processing_detr.py +64 -66
- transformers/models/detr/image_processing_detr_fast.py +33 -34
- transformers/models/detr/modeling_detr.py +748 -789
- transformers/models/dia/configuration_dia.py +9 -15
- transformers/models/dia/feature_extraction_dia.py +6 -9
- transformers/models/dia/generation_dia.py +48 -53
- transformers/models/dia/modeling_dia.py +68 -71
- transformers/models/dia/modular_dia.py +56 -58
- transformers/models/dia/processing_dia.py +39 -29
- transformers/models/dia/tokenization_dia.py +3 -6
- transformers/models/diffllama/configuration_diffllama.py +25 -30
- transformers/models/diffllama/modeling_diffllama.py +45 -53
- transformers/models/diffllama/modular_diffllama.py +18 -25
- transformers/models/dinat/configuration_dinat.py +2 -5
- transformers/models/dinat/modeling_dinat.py +47 -48
- transformers/models/dinov2/configuration_dinov2.py +2 -5
- transformers/models/dinov2/modeling_dinov2.py +20 -21
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
- transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
- transformers/models/distilbert/configuration_distilbert.py +8 -2
- transformers/models/distilbert/modeling_distilbert.py +47 -49
- transformers/models/distilbert/tokenization_distilbert.py +0 -1
- transformers/models/doge/__init__.py +0 -1
- transformers/models/doge/configuration_doge.py +42 -35
- transformers/models/doge/modeling_doge.py +46 -49
- transformers/models/doge/modular_doge.py +77 -68
- transformers/models/donut/configuration_donut_swin.py +0 -1
- transformers/models/donut/image_processing_donut.py +26 -29
- transformers/models/donut/image_processing_donut_fast.py +9 -14
- transformers/models/donut/modeling_donut_swin.py +44 -46
- transformers/models/donut/processing_donut.py +5 -26
- transformers/models/dots1/configuration_dots1.py +43 -36
- transformers/models/dots1/modeling_dots1.py +35 -38
- transformers/models/dots1/modular_dots1.py +0 -1
- transformers/models/dpr/configuration_dpr.py +19 -2
- transformers/models/dpr/modeling_dpr.py +37 -39
- transformers/models/dpr/tokenization_dpr.py +7 -9
- transformers/models/dpr/tokenization_dpr_fast.py +7 -9
- transformers/models/dpt/configuration_dpt.py +23 -66
- transformers/models/dpt/image_processing_dpt.py +65 -66
- transformers/models/dpt/image_processing_dpt_fast.py +18 -19
- transformers/models/dpt/modeling_dpt.py +38 -36
- transformers/models/dpt/modular_dpt.py +14 -15
- transformers/models/edgetam/configuration_edgetam.py +1 -2
- transformers/models/edgetam/modeling_edgetam.py +87 -89
- transformers/models/edgetam/modular_edgetam.py +7 -13
- transformers/models/edgetam_video/__init__.py +0 -1
- transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
- transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
- transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
- transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
- transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
- transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
- transformers/models/efficientnet/configuration_efficientnet.py +0 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
- transformers/models/efficientnet/modeling_efficientnet.py +12 -14
- transformers/models/electra/configuration_electra.py +13 -3
- transformers/models/electra/modeling_electra.py +107 -109
- transformers/models/emu3/configuration_emu3.py +17 -17
- transformers/models/emu3/image_processing_emu3.py +44 -39
- transformers/models/emu3/modeling_emu3.py +143 -109
- transformers/models/emu3/modular_emu3.py +109 -73
- transformers/models/emu3/processing_emu3.py +18 -43
- transformers/models/encodec/configuration_encodec.py +2 -4
- transformers/models/encodec/feature_extraction_encodec.py +10 -13
- transformers/models/encodec/modeling_encodec.py +25 -29
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
- transformers/models/eomt/configuration_eomt.py +12 -14
- transformers/models/eomt/image_processing_eomt.py +53 -55
- transformers/models/eomt/image_processing_eomt_fast.py +18 -19
- transformers/models/eomt/modeling_eomt.py +19 -21
- transformers/models/eomt/modular_eomt.py +28 -30
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -3
- transformers/models/ernie/modeling_ernie.py +127 -162
- transformers/models/ernie/modular_ernie.py +91 -103
- transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
- transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
- transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
- transformers/models/esm/configuration_esm.py +11 -15
- transformers/models/esm/modeling_esm.py +35 -37
- transformers/models/esm/modeling_esmfold.py +43 -50
- transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
- transformers/models/esm/openfold_utils/loss.py +1 -2
- transformers/models/esm/openfold_utils/protein.py +15 -16
- transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
- transformers/models/esm/tokenization_esm.py +2 -4
- transformers/models/evolla/configuration_evolla.py +50 -40
- transformers/models/evolla/modeling_evolla.py +69 -68
- transformers/models/evolla/modular_evolla.py +50 -48
- transformers/models/evolla/processing_evolla.py +23 -35
- transformers/models/exaone4/configuration_exaone4.py +27 -27
- transformers/models/exaone4/modeling_exaone4.py +36 -39
- transformers/models/exaone4/modular_exaone4.py +51 -50
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +31 -26
- transformers/models/falcon/modeling_falcon.py +76 -84
- transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
- transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
- transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
- transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
- transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
- transformers/models/flaubert/configuration_flaubert.py +10 -5
- transformers/models/flaubert/modeling_flaubert.py +125 -129
- transformers/models/flaubert/tokenization_flaubert.py +3 -5
- transformers/models/flava/configuration_flava.py +9 -9
- transformers/models/flava/image_processing_flava.py +66 -67
- transformers/models/flava/image_processing_flava_fast.py +46 -47
- transformers/models/flava/modeling_flava.py +144 -135
- transformers/models/flava/processing_flava.py +2 -12
- transformers/models/flex_olmo/__init__.py +0 -1
- transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
- transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
- transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
- transformers/models/florence2/configuration_florence2.py +4 -1
- transformers/models/florence2/modeling_florence2.py +96 -72
- transformers/models/florence2/modular_florence2.py +100 -107
- transformers/models/florence2/processing_florence2.py +18 -47
- transformers/models/fnet/configuration_fnet.py +6 -2
- transformers/models/fnet/modeling_fnet.py +69 -80
- transformers/models/fnet/tokenization_fnet.py +0 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -5
- transformers/models/focalnet/modeling_focalnet.py +49 -48
- transformers/models/fsmt/configuration_fsmt.py +12 -17
- transformers/models/fsmt/modeling_fsmt.py +47 -48
- transformers/models/fsmt/tokenization_fsmt.py +3 -5
- transformers/models/funnel/configuration_funnel.py +8 -1
- transformers/models/funnel/modeling_funnel.py +91 -93
- transformers/models/funnel/tokenization_funnel.py +2 -5
- transformers/models/fuyu/configuration_fuyu.py +28 -34
- transformers/models/fuyu/image_processing_fuyu.py +29 -31
- transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
- transformers/models/fuyu/modeling_fuyu.py +50 -52
- transformers/models/fuyu/processing_fuyu.py +9 -36
- transformers/models/gemma/configuration_gemma.py +25 -30
- transformers/models/gemma/modeling_gemma.py +36 -38
- transformers/models/gemma/modular_gemma.py +33 -36
- transformers/models/gemma/tokenization_gemma.py +3 -6
- transformers/models/gemma2/configuration_gemma2.py +30 -35
- transformers/models/gemma2/modeling_gemma2.py +38 -41
- transformers/models/gemma2/modular_gemma2.py +63 -67
- transformers/models/gemma3/configuration_gemma3.py +53 -48
- transformers/models/gemma3/image_processing_gemma3.py +29 -31
- transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
- transformers/models/gemma3/modeling_gemma3.py +123 -122
- transformers/models/gemma3/modular_gemma3.py +128 -125
- transformers/models/gemma3/processing_gemma3.py +5 -5
- transformers/models/gemma3n/configuration_gemma3n.py +42 -30
- transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
- transformers/models/gemma3n/modeling_gemma3n.py +166 -147
- transformers/models/gemma3n/modular_gemma3n.py +176 -148
- transformers/models/gemma3n/processing_gemma3n.py +12 -26
- transformers/models/git/configuration_git.py +5 -8
- transformers/models/git/modeling_git.py +115 -127
- transformers/models/git/processing_git.py +2 -14
- transformers/models/glm/configuration_glm.py +26 -30
- transformers/models/glm/modeling_glm.py +36 -39
- transformers/models/glm/modular_glm.py +4 -7
- transformers/models/glm4/configuration_glm4.py +26 -30
- transformers/models/glm4/modeling_glm4.py +39 -41
- transformers/models/glm4/modular_glm4.py +8 -10
- transformers/models/glm46v/configuration_glm46v.py +4 -1
- transformers/models/glm46v/image_processing_glm46v.py +40 -38
- transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
- transformers/models/glm46v/modeling_glm46v.py +138 -93
- transformers/models/glm46v/modular_glm46v.py +5 -3
- transformers/models/glm46v/processing_glm46v.py +7 -41
- transformers/models/glm46v/video_processing_glm46v.py +9 -11
- transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
- transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
- transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
- transformers/models/glm4_moe_lite/__init__.py +28 -0
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
- transformers/models/glm4v/configuration_glm4v.py +25 -24
- transformers/models/glm4v/image_processing_glm4v.py +39 -38
- transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
- transformers/models/glm4v/modeling_glm4v.py +249 -210
- transformers/models/glm4v/modular_glm4v.py +211 -230
- transformers/models/glm4v/processing_glm4v.py +7 -41
- transformers/models/glm4v/video_processing_glm4v.py +9 -11
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
- transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
- transformers/models/glm_image/__init__.py +31 -0
- transformers/models/glm_image/configuration_glm_image.py +358 -0
- transformers/models/glm_image/image_processing_glm_image.py +503 -0
- transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
- transformers/models/glm_image/modeling_glm_image.py +1691 -0
- transformers/models/glm_image/modular_glm_image.py +1640 -0
- transformers/models/glm_image/processing_glm_image.py +265 -0
- transformers/models/glm_ocr/__init__.py +28 -0
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/__init__.py +0 -1
- transformers/models/glmasr/configuration_glmasr.py +0 -1
- transformers/models/glmasr/modeling_glmasr.py +51 -46
- transformers/models/glmasr/modular_glmasr.py +39 -29
- transformers/models/glmasr/processing_glmasr.py +7 -8
- transformers/models/glpn/configuration_glpn.py +0 -1
- transformers/models/glpn/image_processing_glpn.py +11 -12
- transformers/models/glpn/image_processing_glpn_fast.py +11 -12
- transformers/models/glpn/modeling_glpn.py +14 -14
- transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
- transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
- transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
- transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
- transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
- transformers/models/gpt2/configuration_gpt2.py +13 -2
- transformers/models/gpt2/modeling_gpt2.py +111 -113
- transformers/models/gpt2/tokenization_gpt2.py +6 -9
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
- transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
- transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
- transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
- transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
- transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
- transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
- transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
- transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
- transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
- transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
- transformers/models/gptj/configuration_gptj.py +4 -5
- transformers/models/gptj/modeling_gptj.py +85 -88
- transformers/models/granite/configuration_granite.py +28 -33
- transformers/models/granite/modeling_granite.py +43 -45
- transformers/models/granite/modular_granite.py +29 -31
- transformers/models/granite_speech/configuration_granite_speech.py +0 -1
- transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
- transformers/models/granite_speech/modeling_granite_speech.py +84 -60
- transformers/models/granite_speech/processing_granite_speech.py +11 -4
- transformers/models/granitemoe/configuration_granitemoe.py +31 -36
- transformers/models/granitemoe/modeling_granitemoe.py +39 -41
- transformers/models/granitemoe/modular_granitemoe.py +21 -23
- transformers/models/granitemoehybrid/__init__.py +0 -1
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
- transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
- transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
- transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
- transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
- transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
- transformers/models/groupvit/configuration_groupvit.py +4 -2
- transformers/models/groupvit/modeling_groupvit.py +98 -92
- transformers/models/helium/configuration_helium.py +25 -29
- transformers/models/helium/modeling_helium.py +37 -40
- transformers/models/helium/modular_helium.py +3 -7
- transformers/models/herbert/tokenization_herbert.py +4 -6
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
- transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
- transformers/models/hiera/configuration_hiera.py +2 -5
- transformers/models/hiera/modeling_hiera.py +71 -70
- transformers/models/hubert/configuration_hubert.py +4 -2
- transformers/models/hubert/modeling_hubert.py +42 -41
- transformers/models/hubert/modular_hubert.py +8 -11
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
- transformers/models/ibert/configuration_ibert.py +4 -2
- transformers/models/ibert/modeling_ibert.py +60 -62
- transformers/models/ibert/quant_modules.py +0 -1
- transformers/models/idefics/configuration_idefics.py +5 -8
- transformers/models/idefics/image_processing_idefics.py +13 -15
- transformers/models/idefics/modeling_idefics.py +63 -65
- transformers/models/idefics/perceiver.py +1 -3
- transformers/models/idefics/processing_idefics.py +32 -48
- transformers/models/idefics/vision.py +27 -28
- transformers/models/idefics2/configuration_idefics2.py +1 -3
- transformers/models/idefics2/image_processing_idefics2.py +31 -32
- transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
- transformers/models/idefics2/modeling_idefics2.py +126 -106
- transformers/models/idefics2/processing_idefics2.py +10 -68
- transformers/models/idefics3/configuration_idefics3.py +1 -4
- transformers/models/idefics3/image_processing_idefics3.py +42 -43
- transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
- transformers/models/idefics3/modeling_idefics3.py +113 -92
- transformers/models/idefics3/processing_idefics3.py +15 -69
- transformers/models/ijepa/configuration_ijepa.py +0 -1
- transformers/models/ijepa/modeling_ijepa.py +13 -14
- transformers/models/ijepa/modular_ijepa.py +5 -7
- transformers/models/imagegpt/configuration_imagegpt.py +9 -2
- transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
- transformers/models/imagegpt/modeling_imagegpt.py +65 -62
- transformers/models/informer/configuration_informer.py +6 -9
- transformers/models/informer/modeling_informer.py +87 -89
- transformers/models/informer/modular_informer.py +13 -16
- transformers/models/instructblip/configuration_instructblip.py +2 -2
- transformers/models/instructblip/modeling_instructblip.py +104 -79
- transformers/models/instructblip/processing_instructblip.py +10 -36
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
- transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
- transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
- transformers/models/internvl/configuration_internvl.py +5 -1
- transformers/models/internvl/modeling_internvl.py +76 -98
- transformers/models/internvl/modular_internvl.py +45 -59
- transformers/models/internvl/processing_internvl.py +12 -45
- transformers/models/internvl/video_processing_internvl.py +10 -11
- transformers/models/jais2/configuration_jais2.py +25 -29
- transformers/models/jais2/modeling_jais2.py +36 -38
- transformers/models/jais2/modular_jais2.py +20 -22
- transformers/models/jamba/configuration_jamba.py +5 -8
- transformers/models/jamba/modeling_jamba.py +47 -50
- transformers/models/jamba/modular_jamba.py +40 -41
- transformers/models/janus/configuration_janus.py +0 -1
- transformers/models/janus/image_processing_janus.py +37 -39
- transformers/models/janus/image_processing_janus_fast.py +20 -21
- transformers/models/janus/modeling_janus.py +103 -188
- transformers/models/janus/modular_janus.py +122 -83
- transformers/models/janus/processing_janus.py +17 -43
- transformers/models/jetmoe/configuration_jetmoe.py +26 -27
- transformers/models/jetmoe/modeling_jetmoe.py +42 -45
- transformers/models/jetmoe/modular_jetmoe.py +33 -36
- transformers/models/kosmos2/configuration_kosmos2.py +10 -9
- transformers/models/kosmos2/modeling_kosmos2.py +199 -178
- transformers/models/kosmos2/processing_kosmos2.py +40 -55
- transformers/models/kosmos2_5/__init__.py +0 -1
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
- transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
- transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
- transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
- transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
- transformers/models/lasr/configuration_lasr.py +3 -7
- transformers/models/lasr/feature_extraction_lasr.py +10 -12
- transformers/models/lasr/modeling_lasr.py +21 -24
- transformers/models/lasr/modular_lasr.py +11 -13
- transformers/models/lasr/processing_lasr.py +12 -6
- transformers/models/lasr/tokenization_lasr.py +2 -4
- transformers/models/layoutlm/configuration_layoutlm.py +14 -2
- transformers/models/layoutlm/modeling_layoutlm.py +70 -72
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
- transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
- transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
- transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
- transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
- transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
- transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
- transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
- transformers/models/led/configuration_led.py +8 -12
- transformers/models/led/modeling_led.py +113 -267
- transformers/models/levit/configuration_levit.py +0 -1
- transformers/models/levit/image_processing_levit.py +19 -21
- transformers/models/levit/image_processing_levit_fast.py +4 -5
- transformers/models/levit/modeling_levit.py +17 -19
- transformers/models/lfm2/configuration_lfm2.py +27 -30
- transformers/models/lfm2/modeling_lfm2.py +46 -48
- transformers/models/lfm2/modular_lfm2.py +32 -32
- transformers/models/lfm2_moe/__init__.py +0 -1
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
- transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
- transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
- transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
- transformers/models/lightglue/image_processing_lightglue.py +16 -15
- transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
- transformers/models/lightglue/modeling_lightglue.py +31 -33
- transformers/models/lightglue/modular_lightglue.py +31 -31
- transformers/models/lighton_ocr/__init__.py +28 -0
- transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
- transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
- transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
- transformers/models/lilt/configuration_lilt.py +6 -2
- transformers/models/lilt/modeling_lilt.py +53 -55
- transformers/models/llama/configuration_llama.py +26 -31
- transformers/models/llama/modeling_llama.py +35 -38
- transformers/models/llama/tokenization_llama.py +2 -4
- transformers/models/llama4/configuration_llama4.py +87 -69
- transformers/models/llama4/image_processing_llama4_fast.py +11 -12
- transformers/models/llama4/modeling_llama4.py +116 -115
- transformers/models/llama4/processing_llama4.py +33 -57
- transformers/models/llava/configuration_llava.py +10 -1
- transformers/models/llava/image_processing_llava.py +25 -28
- transformers/models/llava/image_processing_llava_fast.py +9 -10
- transformers/models/llava/modeling_llava.py +73 -102
- transformers/models/llava/processing_llava.py +18 -51
- transformers/models/llava_next/configuration_llava_next.py +2 -2
- transformers/models/llava_next/image_processing_llava_next.py +43 -45
- transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
- transformers/models/llava_next/modeling_llava_next.py +103 -104
- transformers/models/llava_next/processing_llava_next.py +18 -47
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
- transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
- transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
- transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
- transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
- transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
- transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
- transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
- transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
- transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
- transformers/models/longcat_flash/__init__.py +0 -1
- transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
- transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
- transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
- transformers/models/longformer/configuration_longformer.py +5 -5
- transformers/models/longformer/modeling_longformer.py +99 -101
- transformers/models/longt5/configuration_longt5.py +9 -7
- transformers/models/longt5/modeling_longt5.py +45 -45
- transformers/models/luke/configuration_luke.py +8 -2
- transformers/models/luke/modeling_luke.py +179 -181
- transformers/models/luke/tokenization_luke.py +99 -105
- transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
- transformers/models/lw_detr/configuration_lw_detr.py +362 -0
- transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
- transformers/models/lw_detr/modular_lw_detr.py +1609 -0
- transformers/models/lxmert/configuration_lxmert.py +16 -1
- transformers/models/lxmert/modeling_lxmert.py +63 -74
- transformers/models/m2m_100/configuration_m2m_100.py +7 -9
- transformers/models/m2m_100/modeling_m2m_100.py +72 -74
- transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
- transformers/models/mamba/configuration_mamba.py +5 -3
- transformers/models/mamba/modeling_mamba.py +61 -70
- transformers/models/mamba2/configuration_mamba2.py +5 -8
- transformers/models/mamba2/modeling_mamba2.py +66 -79
- transformers/models/marian/configuration_marian.py +10 -5
- transformers/models/marian/modeling_marian.py +88 -90
- transformers/models/marian/tokenization_marian.py +6 -6
- transformers/models/markuplm/configuration_markuplm.py +4 -7
- transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
- transformers/models/markuplm/modeling_markuplm.py +63 -65
- transformers/models/markuplm/processing_markuplm.py +31 -38
- transformers/models/markuplm/tokenization_markuplm.py +67 -77
- transformers/models/mask2former/configuration_mask2former.py +14 -52
- transformers/models/mask2former/image_processing_mask2former.py +84 -85
- transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
- transformers/models/mask2former/modeling_mask2former.py +108 -104
- transformers/models/mask2former/modular_mask2former.py +6 -8
- transformers/models/maskformer/configuration_maskformer.py +17 -51
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
- transformers/models/maskformer/image_processing_maskformer.py +84 -85
- transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
- transformers/models/maskformer/modeling_maskformer.py +71 -67
- transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
- transformers/models/mbart/configuration_mbart.py +9 -5
- transformers/models/mbart/modeling_mbart.py +120 -119
- transformers/models/mbart/tokenization_mbart.py +2 -4
- transformers/models/mbart50/tokenization_mbart50.py +3 -5
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
- transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
- transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
- transformers/models/mgp_str/configuration_mgp_str.py +0 -1
- transformers/models/mgp_str/modeling_mgp_str.py +18 -18
- transformers/models/mgp_str/processing_mgp_str.py +3 -20
- transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
- transformers/models/mimi/configuration_mimi.py +42 -40
- transformers/models/mimi/modeling_mimi.py +116 -115
- transformers/models/minimax/__init__.py +0 -1
- transformers/models/minimax/configuration_minimax.py +40 -47
- transformers/models/minimax/modeling_minimax.py +46 -49
- transformers/models/minimax/modular_minimax.py +59 -65
- transformers/models/minimax_m2/__init__.py +28 -0
- transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
- transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
- transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
- transformers/models/ministral/configuration_ministral.py +25 -29
- transformers/models/ministral/modeling_ministral.py +35 -37
- transformers/models/ministral/modular_ministral.py +32 -37
- transformers/models/ministral3/configuration_ministral3.py +23 -26
- transformers/models/ministral3/modeling_ministral3.py +35 -37
- transformers/models/ministral3/modular_ministral3.py +7 -8
- transformers/models/mistral/configuration_mistral.py +24 -29
- transformers/models/mistral/modeling_mistral.py +35 -37
- transformers/models/mistral/modular_mistral.py +14 -15
- transformers/models/mistral3/configuration_mistral3.py +4 -1
- transformers/models/mistral3/modeling_mistral3.py +79 -82
- transformers/models/mistral3/modular_mistral3.py +66 -67
- transformers/models/mixtral/configuration_mixtral.py +32 -38
- transformers/models/mixtral/modeling_mixtral.py +39 -42
- transformers/models/mixtral/modular_mixtral.py +26 -29
- transformers/models/mlcd/configuration_mlcd.py +0 -1
- transformers/models/mlcd/modeling_mlcd.py +17 -17
- transformers/models/mlcd/modular_mlcd.py +16 -16
- transformers/models/mllama/configuration_mllama.py +10 -15
- transformers/models/mllama/image_processing_mllama.py +23 -25
- transformers/models/mllama/image_processing_mllama_fast.py +11 -11
- transformers/models/mllama/modeling_mllama.py +100 -103
- transformers/models/mllama/processing_mllama.py +6 -55
- transformers/models/mluke/tokenization_mluke.py +97 -103
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
- transformers/models/mobilebert/configuration_mobilebert.py +4 -2
- transformers/models/mobilebert/modeling_mobilebert.py +78 -88
- transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
- transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
- transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
- transformers/models/mobilevit/configuration_mobilevit.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
- transformers/models/mobilevit/modeling_mobilevit.py +21 -21
- transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
- transformers/models/modernbert/configuration_modernbert.py +76 -51
- transformers/models/modernbert/modeling_modernbert.py +188 -943
- transformers/models/modernbert/modular_modernbert.py +255 -978
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
- transformers/models/moonshine/configuration_moonshine.py +34 -31
- transformers/models/moonshine/modeling_moonshine.py +70 -72
- transformers/models/moonshine/modular_moonshine.py +91 -86
- transformers/models/moshi/configuration_moshi.py +46 -23
- transformers/models/moshi/modeling_moshi.py +134 -142
- transformers/models/mpnet/configuration_mpnet.py +6 -2
- transformers/models/mpnet/modeling_mpnet.py +55 -57
- transformers/models/mpnet/tokenization_mpnet.py +1 -4
- transformers/models/mpt/configuration_mpt.py +17 -9
- transformers/models/mpt/modeling_mpt.py +58 -60
- transformers/models/mra/configuration_mra.py +8 -2
- transformers/models/mra/modeling_mra.py +54 -56
- transformers/models/mt5/configuration_mt5.py +9 -6
- transformers/models/mt5/modeling_mt5.py +80 -85
- transformers/models/musicgen/configuration_musicgen.py +12 -8
- transformers/models/musicgen/modeling_musicgen.py +114 -116
- transformers/models/musicgen/processing_musicgen.py +3 -21
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
- transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
- transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
- transformers/models/mvp/configuration_mvp.py +8 -5
- transformers/models/mvp/modeling_mvp.py +121 -123
- transformers/models/myt5/tokenization_myt5.py +8 -10
- transformers/models/nanochat/configuration_nanochat.py +5 -8
- transformers/models/nanochat/modeling_nanochat.py +36 -39
- transformers/models/nanochat/modular_nanochat.py +16 -18
- transformers/models/nemotron/configuration_nemotron.py +25 -30
- transformers/models/nemotron/modeling_nemotron.py +53 -66
- transformers/models/nllb/tokenization_nllb.py +14 -14
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
- transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
- transformers/models/nougat/image_processing_nougat.py +29 -32
- transformers/models/nougat/image_processing_nougat_fast.py +12 -13
- transformers/models/nougat/processing_nougat.py +37 -39
- transformers/models/nougat/tokenization_nougat.py +5 -7
- transformers/models/nystromformer/configuration_nystromformer.py +8 -2
- transformers/models/nystromformer/modeling_nystromformer.py +61 -63
- transformers/models/olmo/configuration_olmo.py +23 -28
- transformers/models/olmo/modeling_olmo.py +35 -38
- transformers/models/olmo/modular_olmo.py +8 -12
- transformers/models/olmo2/configuration_olmo2.py +27 -32
- transformers/models/olmo2/modeling_olmo2.py +36 -39
- transformers/models/olmo2/modular_olmo2.py +36 -38
- transformers/models/olmo3/__init__.py +0 -1
- transformers/models/olmo3/configuration_olmo3.py +30 -34
- transformers/models/olmo3/modeling_olmo3.py +35 -38
- transformers/models/olmo3/modular_olmo3.py +44 -47
- transformers/models/olmoe/configuration_olmoe.py +29 -33
- transformers/models/olmoe/modeling_olmoe.py +41 -43
- transformers/models/olmoe/modular_olmoe.py +15 -16
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
- transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
- transformers/models/oneformer/configuration_oneformer.py +11 -51
- transformers/models/oneformer/image_processing_oneformer.py +83 -84
- transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
- transformers/models/oneformer/modeling_oneformer.py +137 -133
- transformers/models/oneformer/processing_oneformer.py +28 -43
- transformers/models/openai/configuration_openai.py +16 -1
- transformers/models/openai/modeling_openai.py +50 -51
- transformers/models/openai/tokenization_openai.py +2 -5
- transformers/models/opt/configuration_opt.py +6 -7
- transformers/models/opt/modeling_opt.py +79 -80
- transformers/models/ovis2/__init__.py +0 -1
- transformers/models/ovis2/configuration_ovis2.py +4 -1
- transformers/models/ovis2/image_processing_ovis2.py +22 -24
- transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
- transformers/models/ovis2/modeling_ovis2.py +99 -142
- transformers/models/ovis2/modular_ovis2.py +82 -45
- transformers/models/ovis2/processing_ovis2.py +12 -40
- transformers/models/owlv2/configuration_owlv2.py +4 -2
- transformers/models/owlv2/image_processing_owlv2.py +20 -21
- transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
- transformers/models/owlv2/modeling_owlv2.py +122 -114
- transformers/models/owlv2/modular_owlv2.py +11 -12
- transformers/models/owlv2/processing_owlv2.py +20 -49
- transformers/models/owlvit/configuration_owlvit.py +4 -2
- transformers/models/owlvit/image_processing_owlvit.py +21 -22
- transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
- transformers/models/owlvit/modeling_owlvit.py +121 -113
- transformers/models/owlvit/processing_owlvit.py +20 -48
- transformers/models/paddleocr_vl/__init__.py +0 -1
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
- transformers/models/paligemma/configuration_paligemma.py +4 -1
- transformers/models/paligemma/modeling_paligemma.py +81 -79
- transformers/models/paligemma/processing_paligemma.py +13 -66
- transformers/models/parakeet/configuration_parakeet.py +3 -8
- transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
- transformers/models/parakeet/modeling_parakeet.py +21 -25
- transformers/models/parakeet/modular_parakeet.py +19 -21
- transformers/models/parakeet/processing_parakeet.py +12 -5
- transformers/models/parakeet/tokenization_parakeet.py +2 -4
- transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
- transformers/models/patchtst/configuration_patchtst.py +6 -9
- transformers/models/patchtst/modeling_patchtst.py +75 -77
- transformers/models/pe_audio/__init__.py +0 -1
- transformers/models/pe_audio/configuration_pe_audio.py +14 -16
- transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
- transformers/models/pe_audio/modeling_pe_audio.py +30 -31
- transformers/models/pe_audio/modular_pe_audio.py +17 -18
- transformers/models/pe_audio/processing_pe_audio.py +0 -1
- transformers/models/pe_audio_video/__init__.py +0 -1
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
- transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
- transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
- transformers/models/pe_video/__init__.py +0 -1
- transformers/models/pe_video/configuration_pe_video.py +14 -16
- transformers/models/pe_video/modeling_pe_video.py +57 -46
- transformers/models/pe_video/modular_pe_video.py +47 -35
- transformers/models/pe_video/video_processing_pe_video.py +2 -4
- transformers/models/pegasus/configuration_pegasus.py +8 -6
- transformers/models/pegasus/modeling_pegasus.py +67 -69
- transformers/models/pegasus/tokenization_pegasus.py +1 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
- transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
- transformers/models/perceiver/configuration_perceiver.py +0 -1
- transformers/models/perceiver/image_processing_perceiver.py +22 -25
- transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
- transformers/models/perceiver/modeling_perceiver.py +152 -145
- transformers/models/perceiver/tokenization_perceiver.py +3 -6
- transformers/models/perception_lm/configuration_perception_lm.py +0 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
- transformers/models/perception_lm/modeling_perception_lm.py +64 -67
- transformers/models/perception_lm/modular_perception_lm.py +58 -58
- transformers/models/perception_lm/processing_perception_lm.py +13 -47
- transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
- transformers/models/persimmon/configuration_persimmon.py +23 -28
- transformers/models/persimmon/modeling_persimmon.py +44 -47
- transformers/models/phi/configuration_phi.py +27 -28
- transformers/models/phi/modeling_phi.py +39 -41
- transformers/models/phi/modular_phi.py +26 -26
- transformers/models/phi3/configuration_phi3.py +32 -37
- transformers/models/phi3/modeling_phi3.py +37 -40
- transformers/models/phi3/modular_phi3.py +16 -20
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
- transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
- transformers/models/phimoe/configuration_phimoe.py +31 -36
- transformers/models/phimoe/modeling_phimoe.py +50 -77
- transformers/models/phimoe/modular_phimoe.py +12 -8
- transformers/models/phobert/tokenization_phobert.py +4 -6
- transformers/models/pix2struct/configuration_pix2struct.py +12 -10
- transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
- transformers/models/pix2struct/modeling_pix2struct.py +56 -52
- transformers/models/pix2struct/processing_pix2struct.py +5 -26
- transformers/models/pixio/__init__.py +0 -1
- transformers/models/pixio/configuration_pixio.py +2 -5
- transformers/models/pixio/modeling_pixio.py +16 -17
- transformers/models/pixio/modular_pixio.py +7 -8
- transformers/models/pixtral/configuration_pixtral.py +11 -14
- transformers/models/pixtral/image_processing_pixtral.py +26 -28
- transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
- transformers/models/pixtral/modeling_pixtral.py +31 -37
- transformers/models/pixtral/processing_pixtral.py +18 -52
- transformers/models/plbart/configuration_plbart.py +8 -6
- transformers/models/plbart/modeling_plbart.py +109 -109
- transformers/models/plbart/modular_plbart.py +31 -33
- transformers/models/plbart/tokenization_plbart.py +4 -5
- transformers/models/poolformer/configuration_poolformer.py +0 -1
- transformers/models/poolformer/image_processing_poolformer.py +21 -24
- transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
- transformers/models/poolformer/modeling_poolformer.py +10 -12
- transformers/models/pop2piano/configuration_pop2piano.py +7 -7
- transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
- transformers/models/pop2piano/modeling_pop2piano.py +24 -24
- transformers/models/pop2piano/processing_pop2piano.py +25 -33
- transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
- transformers/models/prophetnet/configuration_prophetnet.py +37 -38
- transformers/models/prophetnet/modeling_prophetnet.py +121 -153
- transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
- transformers/models/pvt/configuration_pvt.py +0 -1
- transformers/models/pvt/image_processing_pvt.py +24 -27
- transformers/models/pvt/image_processing_pvt_fast.py +1 -2
- transformers/models/pvt/modeling_pvt.py +19 -21
- transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
- transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
- transformers/models/qwen2/configuration_qwen2.py +32 -25
- transformers/models/qwen2/modeling_qwen2.py +35 -37
- transformers/models/qwen2/modular_qwen2.py +14 -15
- transformers/models/qwen2/tokenization_qwen2.py +2 -9
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
- transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
- transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
- transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
- transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
- transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
- transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
- transformers/models/qwen3/configuration_qwen3.py +34 -27
- transformers/models/qwen3/modeling_qwen3.py +35 -38
- transformers/models/qwen3/modular_qwen3.py +7 -9
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
- transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
- transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
- transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
- transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
- transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
- transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
- transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
- transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
- transformers/models/rag/configuration_rag.py +6 -7
- transformers/models/rag/modeling_rag.py +119 -121
- transformers/models/rag/retrieval_rag.py +3 -5
- transformers/models/rag/tokenization_rag.py +0 -50
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
- transformers/models/reformer/configuration_reformer.py +7 -8
- transformers/models/reformer/modeling_reformer.py +67 -68
- transformers/models/reformer/tokenization_reformer.py +3 -6
- transformers/models/regnet/configuration_regnet.py +0 -1
- transformers/models/regnet/modeling_regnet.py +7 -9
- transformers/models/rembert/configuration_rembert.py +8 -2
- transformers/models/rembert/modeling_rembert.py +108 -132
- transformers/models/rembert/tokenization_rembert.py +1 -4
- transformers/models/resnet/configuration_resnet.py +2 -5
- transformers/models/resnet/modeling_resnet.py +14 -15
- transformers/models/roberta/configuration_roberta.py +11 -3
- transformers/models/roberta/modeling_roberta.py +97 -99
- transformers/models/roberta/modular_roberta.py +55 -58
- transformers/models/roberta/tokenization_roberta.py +2 -5
- transformers/models/roberta/tokenization_roberta_old.py +2 -4
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
- transformers/models/roc_bert/configuration_roc_bert.py +8 -2
- transformers/models/roc_bert/modeling_roc_bert.py +125 -162
- transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
- transformers/models/roformer/configuration_roformer.py +13 -3
- transformers/models/roformer/modeling_roformer.py +79 -95
- transformers/models/roformer/tokenization_roformer.py +3 -6
- transformers/models/roformer/tokenization_utils.py +0 -1
- transformers/models/rt_detr/configuration_rt_detr.py +8 -50
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
- transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
- transformers/models/rt_detr/modeling_rt_detr.py +643 -804
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
- transformers/models/rt_detr/modular_rt_detr.py +1522 -20
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
- transformers/models/rwkv/configuration_rwkv.py +2 -4
- transformers/models/rwkv/modeling_rwkv.py +29 -54
- transformers/models/sam/configuration_sam.py +2 -1
- transformers/models/sam/image_processing_sam.py +59 -60
- transformers/models/sam/image_processing_sam_fast.py +25 -26
- transformers/models/sam/modeling_sam.py +46 -43
- transformers/models/sam/processing_sam.py +39 -27
- transformers/models/sam2/configuration_sam2.py +1 -2
- transformers/models/sam2/image_processing_sam2_fast.py +14 -15
- transformers/models/sam2/modeling_sam2.py +96 -94
- transformers/models/sam2/modular_sam2.py +85 -94
- transformers/models/sam2/processing_sam2.py +31 -47
- transformers/models/sam2_video/configuration_sam2_video.py +0 -1
- transformers/models/sam2_video/modeling_sam2_video.py +114 -116
- transformers/models/sam2_video/modular_sam2_video.py +72 -89
- transformers/models/sam2_video/processing_sam2_video.py +49 -66
- transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
- transformers/models/sam3/configuration_sam3.py +0 -1
- transformers/models/sam3/image_processing_sam3_fast.py +17 -20
- transformers/models/sam3/modeling_sam3.py +94 -100
- transformers/models/sam3/modular_sam3.py +3 -8
- transformers/models/sam3/processing_sam3.py +37 -52
- transformers/models/sam3_tracker/__init__.py +0 -1
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
- transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
- transformers/models/sam3_tracker_video/__init__.py +0 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
- transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
- transformers/models/sam3_video/configuration_sam3_video.py +0 -1
- transformers/models/sam3_video/modeling_sam3_video.py +56 -45
- transformers/models/sam3_video/processing_sam3_video.py +25 -45
- transformers/models/sam_hq/__init__.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +2 -1
- transformers/models/sam_hq/modeling_sam_hq.py +52 -50
- transformers/models/sam_hq/modular_sam_hq.py +23 -25
- transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
- transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
- transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
- transformers/models/seed_oss/configuration_seed_oss.py +30 -34
- transformers/models/seed_oss/modeling_seed_oss.py +34 -36
- transformers/models/seed_oss/modular_seed_oss.py +6 -7
- transformers/models/segformer/configuration_segformer.py +0 -10
- transformers/models/segformer/image_processing_segformer.py +39 -42
- transformers/models/segformer/image_processing_segformer_fast.py +11 -12
- transformers/models/segformer/modeling_segformer.py +28 -28
- transformers/models/segformer/modular_segformer.py +8 -9
- transformers/models/seggpt/configuration_seggpt.py +0 -1
- transformers/models/seggpt/image_processing_seggpt.py +38 -41
- transformers/models/seggpt/modeling_seggpt.py +48 -38
- transformers/models/sew/configuration_sew.py +4 -2
- transformers/models/sew/modeling_sew.py +42 -40
- transformers/models/sew/modular_sew.py +12 -13
- transformers/models/sew_d/configuration_sew_d.py +4 -2
- transformers/models/sew_d/modeling_sew_d.py +32 -31
- transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
- transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
- transformers/models/siglip/configuration_siglip.py +4 -2
- transformers/models/siglip/image_processing_siglip.py +17 -20
- transformers/models/siglip/image_processing_siglip_fast.py +0 -1
- transformers/models/siglip/modeling_siglip.py +65 -110
- transformers/models/siglip/processing_siglip.py +2 -14
- transformers/models/siglip/tokenization_siglip.py +6 -7
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2.py +15 -16
- transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
- transformers/models/siglip2/modeling_siglip2.py +89 -130
- transformers/models/siglip2/modular_siglip2.py +95 -48
- transformers/models/siglip2/processing_siglip2.py +2 -14
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +29 -32
- transformers/models/smollm3/modeling_smollm3.py +35 -38
- transformers/models/smollm3/modular_smollm3.py +36 -38
- transformers/models/smolvlm/configuration_smolvlm.py +2 -4
- transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
- transformers/models/smolvlm/modeling_smolvlm.py +124 -96
- transformers/models/smolvlm/modular_smolvlm.py +50 -39
- transformers/models/smolvlm/processing_smolvlm.py +15 -76
- transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
- transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
- transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
- transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
- transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
- transformers/models/speecht5/configuration_speecht5.py +7 -9
- transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
- transformers/models/speecht5/modeling_speecht5.py +172 -174
- transformers/models/speecht5/number_normalizer.py +0 -1
- transformers/models/speecht5/processing_speecht5.py +3 -37
- transformers/models/speecht5/tokenization_speecht5.py +4 -5
- transformers/models/splinter/configuration_splinter.py +6 -7
- transformers/models/splinter/modeling_splinter.py +62 -59
- transformers/models/splinter/tokenization_splinter.py +2 -4
- transformers/models/squeezebert/configuration_squeezebert.py +14 -2
- transformers/models/squeezebert/modeling_squeezebert.py +60 -62
- transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
- transformers/models/stablelm/configuration_stablelm.py +28 -29
- transformers/models/stablelm/modeling_stablelm.py +44 -47
- transformers/models/starcoder2/configuration_starcoder2.py +30 -27
- transformers/models/starcoder2/modeling_starcoder2.py +38 -41
- transformers/models/starcoder2/modular_starcoder2.py +17 -19
- transformers/models/superglue/configuration_superglue.py +7 -3
- transformers/models/superglue/image_processing_superglue.py +15 -15
- transformers/models/superglue/image_processing_superglue_fast.py +8 -8
- transformers/models/superglue/modeling_superglue.py +41 -37
- transformers/models/superpoint/image_processing_superpoint.py +15 -15
- transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
- transformers/models/superpoint/modeling_superpoint.py +17 -16
- transformers/models/swiftformer/configuration_swiftformer.py +0 -1
- transformers/models/swiftformer/modeling_swiftformer.py +12 -14
- transformers/models/swin/configuration_swin.py +2 -5
- transformers/models/swin/modeling_swin.py +69 -78
- transformers/models/swin2sr/configuration_swin2sr.py +0 -1
- transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
- transformers/models/swin2sr/modeling_swin2sr.py +30 -30
- transformers/models/swinv2/configuration_swinv2.py +2 -5
- transformers/models/swinv2/modeling_swinv2.py +65 -74
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
- transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
- transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
- transformers/models/t5/configuration_t5.py +9 -9
- transformers/models/t5/modeling_t5.py +80 -85
- transformers/models/t5/tokenization_t5.py +1 -3
- transformers/models/t5gemma/configuration_t5gemma.py +43 -59
- transformers/models/t5gemma/modeling_t5gemma.py +105 -108
- transformers/models/t5gemma/modular_t5gemma.py +128 -142
- transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
- transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
- transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
- transformers/models/table_transformer/configuration_table_transformer.py +18 -50
- transformers/models/table_transformer/modeling_table_transformer.py +73 -101
- transformers/models/tapas/configuration_tapas.py +12 -2
- transformers/models/tapas/modeling_tapas.py +65 -67
- transformers/models/tapas/tokenization_tapas.py +116 -153
- transformers/models/textnet/configuration_textnet.py +4 -7
- transformers/models/textnet/image_processing_textnet.py +22 -25
- transformers/models/textnet/image_processing_textnet_fast.py +8 -9
- transformers/models/textnet/modeling_textnet.py +28 -28
- transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
- transformers/models/timesfm/configuration_timesfm.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +22 -25
- transformers/models/timesfm/modular_timesfm.py +21 -24
- transformers/models/timesformer/configuration_timesformer.py +0 -1
- transformers/models/timesformer/modeling_timesformer.py +13 -16
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
- transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
- transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
- transformers/models/trocr/configuration_trocr.py +11 -8
- transformers/models/trocr/modeling_trocr.py +42 -42
- transformers/models/trocr/processing_trocr.py +5 -25
- transformers/models/tvp/configuration_tvp.py +10 -36
- transformers/models/tvp/image_processing_tvp.py +50 -52
- transformers/models/tvp/image_processing_tvp_fast.py +15 -15
- transformers/models/tvp/modeling_tvp.py +26 -28
- transformers/models/tvp/processing_tvp.py +2 -14
- transformers/models/udop/configuration_udop.py +16 -8
- transformers/models/udop/modeling_udop.py +73 -72
- transformers/models/udop/processing_udop.py +7 -26
- transformers/models/udop/tokenization_udop.py +80 -93
- transformers/models/umt5/configuration_umt5.py +8 -7
- transformers/models/umt5/modeling_umt5.py +87 -84
- transformers/models/unispeech/configuration_unispeech.py +4 -2
- transformers/models/unispeech/modeling_unispeech.py +54 -53
- transformers/models/unispeech/modular_unispeech.py +20 -22
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
- transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
- transformers/models/univnet/feature_extraction_univnet.py +14 -14
- transformers/models/univnet/modeling_univnet.py +7 -8
- transformers/models/upernet/configuration_upernet.py +8 -36
- transformers/models/upernet/modeling_upernet.py +11 -14
- transformers/models/vaultgemma/__init__.py +0 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
- transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
- transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
- transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
- transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
- transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
- transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
- transformers/models/video_llava/configuration_video_llava.py +4 -1
- transformers/models/video_llava/image_processing_video_llava.py +35 -38
- transformers/models/video_llava/modeling_video_llava.py +139 -143
- transformers/models/video_llava/processing_video_llava.py +38 -78
- transformers/models/video_llava/video_processing_video_llava.py +0 -1
- transformers/models/videomae/configuration_videomae.py +0 -1
- transformers/models/videomae/image_processing_videomae.py +31 -34
- transformers/models/videomae/modeling_videomae.py +17 -20
- transformers/models/videomae/video_processing_videomae.py +0 -1
- transformers/models/vilt/configuration_vilt.py +4 -2
- transformers/models/vilt/image_processing_vilt.py +29 -30
- transformers/models/vilt/image_processing_vilt_fast.py +15 -16
- transformers/models/vilt/modeling_vilt.py +103 -90
- transformers/models/vilt/processing_vilt.py +2 -14
- transformers/models/vipllava/configuration_vipllava.py +4 -1
- transformers/models/vipllava/modeling_vipllava.py +92 -67
- transformers/models/vipllava/modular_vipllava.py +78 -54
- transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
- transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
- transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
- transformers/models/visual_bert/configuration_visual_bert.py +6 -2
- transformers/models/visual_bert/modeling_visual_bert.py +90 -92
- transformers/models/vit/configuration_vit.py +2 -3
- transformers/models/vit/image_processing_vit.py +19 -22
- transformers/models/vit/image_processing_vit_fast.py +0 -1
- transformers/models/vit/modeling_vit.py +20 -20
- transformers/models/vit_mae/configuration_vit_mae.py +0 -1
- transformers/models/vit_mae/modeling_vit_mae.py +32 -30
- transformers/models/vit_msn/configuration_vit_msn.py +0 -1
- transformers/models/vit_msn/modeling_vit_msn.py +21 -19
- transformers/models/vitdet/configuration_vitdet.py +2 -5
- transformers/models/vitdet/modeling_vitdet.py +14 -17
- transformers/models/vitmatte/configuration_vitmatte.py +7 -39
- transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
- transformers/models/vitmatte/modeling_vitmatte.py +10 -12
- transformers/models/vitpose/configuration_vitpose.py +7 -47
- transformers/models/vitpose/image_processing_vitpose.py +24 -25
- transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
- transformers/models/vitpose/modeling_vitpose.py +15 -15
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
- transformers/models/vits/configuration_vits.py +4 -1
- transformers/models/vits/modeling_vits.py +43 -42
- transformers/models/vits/tokenization_vits.py +3 -4
- transformers/models/vivit/configuration_vivit.py +0 -1
- transformers/models/vivit/image_processing_vivit.py +36 -39
- transformers/models/vivit/modeling_vivit.py +9 -11
- transformers/models/vjepa2/__init__.py +0 -1
- transformers/models/vjepa2/configuration_vjepa2.py +0 -1
- transformers/models/vjepa2/modeling_vjepa2.py +39 -41
- transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
- transformers/models/voxtral/__init__.py +0 -1
- transformers/models/voxtral/configuration_voxtral.py +0 -2
- transformers/models/voxtral/modeling_voxtral.py +41 -48
- transformers/models/voxtral/modular_voxtral.py +35 -38
- transformers/models/voxtral/processing_voxtral.py +25 -48
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
- transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
- transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
- transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
- transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
- transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
- transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
- transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
- transformers/models/wavlm/configuration_wavlm.py +4 -2
- transformers/models/wavlm/modeling_wavlm.py +49 -49
- transformers/models/wavlm/modular_wavlm.py +4 -5
- transformers/models/whisper/configuration_whisper.py +6 -5
- transformers/models/whisper/english_normalizer.py +3 -4
- transformers/models/whisper/feature_extraction_whisper.py +9 -24
- transformers/models/whisper/generation_whisper.py +26 -49
- transformers/models/whisper/modeling_whisper.py +71 -73
- transformers/models/whisper/processing_whisper.py +3 -20
- transformers/models/whisper/tokenization_whisper.py +9 -30
- transformers/models/x_clip/configuration_x_clip.py +4 -2
- transformers/models/x_clip/modeling_x_clip.py +94 -96
- transformers/models/x_clip/processing_x_clip.py +2 -14
- transformers/models/xcodec/configuration_xcodec.py +4 -6
- transformers/models/xcodec/modeling_xcodec.py +15 -17
- transformers/models/xglm/configuration_xglm.py +9 -8
- transformers/models/xglm/modeling_xglm.py +49 -55
- transformers/models/xglm/tokenization_xglm.py +1 -4
- transformers/models/xlm/configuration_xlm.py +10 -8
- transformers/models/xlm/modeling_xlm.py +127 -131
- transformers/models/xlm/tokenization_xlm.py +3 -5
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
- transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
- transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
- transformers/models/xlnet/configuration_xlnet.py +3 -12
- transformers/models/xlnet/modeling_xlnet.py +149 -162
- transformers/models/xlnet/tokenization_xlnet.py +1 -4
- transformers/models/xlstm/configuration_xlstm.py +8 -12
- transformers/models/xlstm/modeling_xlstm.py +61 -96
- transformers/models/xmod/configuration_xmod.py +11 -3
- transformers/models/xmod/modeling_xmod.py +111 -116
- transformers/models/yolos/configuration_yolos.py +0 -1
- transformers/models/yolos/image_processing_yolos.py +60 -62
- transformers/models/yolos/image_processing_yolos_fast.py +42 -45
- transformers/models/yolos/modeling_yolos.py +19 -21
- transformers/models/yolos/modular_yolos.py +17 -19
- transformers/models/yoso/configuration_yoso.py +8 -2
- transformers/models/yoso/modeling_yoso.py +60 -62
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -8
- transformers/models/zamba/modeling_zamba.py +93 -125
- transformers/models/zamba2/configuration_zamba2.py +44 -50
- transformers/models/zamba2/modeling_zamba2.py +137 -165
- transformers/models/zamba2/modular_zamba2.py +79 -74
- transformers/models/zoedepth/configuration_zoedepth.py +17 -41
- transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
- transformers/models/zoedepth/modeling_zoedepth.py +19 -19
- transformers/pipelines/__init__.py +47 -106
- transformers/pipelines/any_to_any.py +15 -23
- transformers/pipelines/audio_utils.py +1 -2
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +13 -17
- transformers/pipelines/image_text_to_text.py +1 -2
- transformers/pipelines/question_answering.py +4 -43
- transformers/pipelines/text_classification.py +1 -14
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/pipelines/token_classification.py +1 -22
- transformers/pipelines/video_classification.py +1 -9
- transformers/pipelines/zero_shot_audio_classification.py +0 -1
- transformers/pipelines/zero_shot_classification.py +0 -6
- transformers/pipelines/zero_shot_image_classification.py +0 -7
- transformers/processing_utils.py +128 -137
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/base.py +10 -0
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_quark.py +0 -1
- transformers/quantizers/quantizer_torchao.py +3 -19
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +6 -65
- transformers/tokenization_mistral_common.py +563 -903
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +228 -341
- transformers/tokenization_utils_sentencepiece.py +5 -6
- transformers/tokenization_utils_tokenizers.py +36 -7
- transformers/trainer.py +30 -41
- transformers/trainer_jit_checkpoint.py +1 -2
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +414 -420
- transformers/utils/__init__.py +1 -4
- transformers/utils/attention_visualizer.py +1 -1
- transformers/utils/auto_docstring.py +567 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/dummy_pt_objects.py +0 -42
- transformers/utils/generic.py +70 -34
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +135 -107
- transformers/utils/quantization_config.py +8 -31
- transformers/video_processing_utils.py +24 -25
- transformers/video_utils.py +21 -23
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
- transformers-5.1.0.dist-info/RECORD +2092 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -229
- transformers-5.0.0rc2.dist-info/RECORD +0 -2042
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,33 +14,35 @@
 """PyTorch DETR model."""
 
 import math
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Optional, Union
 
 import torch
-
+import torch.nn as nn
 
 from ... import initialization as init
 from ...activations import ACT2FN
-from ...
+from ...backbone_utils import load_backbone
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import
-
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithCrossAttentions,
+    Seq2SeqModelOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import (
     ModelOutput,
+    TransformersKwargs,
     auto_docstring,
-    is_timm_available,
     logging,
-    requires_backends,
 )
-from ...utils.
+from ...utils.generic import can_return_tuple, check_model_inputs
 from .configuration_detr import DetrConfig
 
 
-if is_timm_available():
-    from timm import create_model
-
-
 logger = logging.get_logger(__name__)
 
 
@@ -64,7 +65,7 @@ class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
             layernorm.
     """
 
-    intermediate_hidden_states:
+    intermediate_hidden_states: torch.FloatTensor | None = None
 
 
 @dataclass
@@ -84,7 +85,7 @@ class DetrModelOutput(Seq2SeqModelOutput):
             layernorm.
     """
 
-    intermediate_hidden_states:
+    intermediate_hidden_states: torch.FloatTensor | None = None
 
 
 @dataclass
@@ -116,18 +117,18 @@ class DetrObjectDetectionOutput(ModelOutput):
             Sequence of hidden-states at the output of the last layer of the decoder of the model.
     """
 
-    loss:
-    loss_dict:
-    logits:
-    pred_boxes:
-    auxiliary_outputs:
-    last_hidden_state:
-    decoder_hidden_states:
-    decoder_attentions:
-    cross_attentions:
-    encoder_last_hidden_state:
-    encoder_hidden_states:
-    encoder_attentions:
+    loss: torch.FloatTensor | None = None
+    loss_dict: dict | None = None
+    logits: torch.FloatTensor | None = None
+    pred_boxes: torch.FloatTensor | None = None
+    auxiliary_outputs: list[dict] | None = None
+    last_hidden_state: torch.FloatTensor | None = None
+    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    decoder_attentions: tuple[torch.FloatTensor] | None = None
+    cross_attentions: tuple[torch.FloatTensor] | None = None
+    encoder_last_hidden_state: torch.FloatTensor | None = None
+    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    encoder_attentions: tuple[torch.FloatTensor] | None = None
 
 
 @dataclass
@@ -165,23 +166,21 @@ class DetrSegmentationOutput(ModelOutput):
             Sequence of hidden-states at the output of the last layer of the decoder of the model.
     """
 
-    loss:
-    loss_dict:
-    logits:
-    pred_boxes:
-    pred_masks:
-    auxiliary_outputs:
-    last_hidden_state:
-    decoder_hidden_states:
-    decoder_attentions:
-    cross_attentions:
-    encoder_last_hidden_state:
-    encoder_hidden_states:
-    encoder_attentions:
-
-
-# BELOW: utilities copied from
-# https://github.com/facebookresearch/detr/blob/master/backbone.py
+    loss: torch.FloatTensor | None = None
+    loss_dict: dict | None = None
+    logits: torch.FloatTensor | None = None
+    pred_boxes: torch.FloatTensor | None = None
+    pred_masks: torch.FloatTensor | None = None
+    auxiliary_outputs: list[dict] | None = None
+    last_hidden_state: torch.FloatTensor | None = None
+    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    decoder_attentions: tuple[torch.FloatTensor] | None = None
+    cross_attentions: tuple[torch.FloatTensor] | None = None
+    encoder_last_hidden_state: torch.FloatTensor | None = None
+    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    encoder_attentions: tuple[torch.FloatTensor] | None = None
+
+
 class DetrFrozenBatchNorm2d(nn.Module):
     """
     BatchNorm2d where the batch statistics and the affine parameters are fixed.
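Note on the four hunks above: every output dataclass field moves to PEP 604 unions and builtin generics, which is why `from typing import Optional, Union` disappears from the import hunk earlier in the file. A minimal sketch of the style change; the `Optional[...]` forms on the old side are an assumption, since the viewer truncates the removed lines:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Before:
    # transformers<=5.0.0rc2 style (assumed): typing.Optional wrappers
    loss: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None


@dataclass
class After:
    # transformers 5.1.0 style, as shown in the hunks: PEP 604 unions
    loss: torch.FloatTensor | None = None
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
```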
@@ -258,47 +257,25 @@ class DetrConvEncoder(nn.Module):
 
         self.config = config
 
-
-
-        # We default to values which were previously hard-coded. This enables configurability from the config
-        # using backbone arguments, while keeping the default behavior the same.
-        requires_backends(self, ["timm"])
-        kwargs = getattr(config, "backbone_kwargs", {})
-        kwargs = {} if kwargs is None else kwargs.copy()
-        out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
-        num_channels = kwargs.pop("in_chans", config.num_channels)
-        if config.dilation:
-            kwargs["output_stride"] = kwargs.get("output_stride", 16)
-        backbone = create_model(
-            config.backbone,
-            pretrained=config.use_pretrained_backbone,
-            features_only=True,
-            out_indices=out_indices,
-            in_chans=num_channels,
-            **kwargs,
-        )
-        else:
-            backbone = load_backbone(config)
+        backbone = load_backbone(config)
+        self.intermediate_channel_sizes = backbone.channels
 
         # replace batch norm by frozen batch norm
         with torch.no_grad():
             replace_batch_norm(backbone)
-        self.model = backbone
-        self.intermediate_channel_sizes = (
-            self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
-        )
 
-
-
-
-
-
-
-
+        # We used to load with timm library directly instead of the AutoBackbone API
+        # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
+        is_timm_model = False
+        if hasattr(backbone, "_backbone"):
+            backbone = backbone._backbone
+            is_timm_model = True
+        self.model = backbone
 
+        backbone_model_type = config.backbone_config.model_type
         if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
-                if
+                if is_timm_model:
                     if "layer2" not in name and "layer3" not in name and "layer4" not in name:
                         parameter.requires_grad_(False)
                 else:
@@ -307,7 +284,9 @@ class DetrConvEncoder(nn.Module):
 
     def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
         # send pixel_values through the model to get list of feature maps
-        features = self.model(pixel_values)
+        features = self.model(pixel_values)
+        if isinstance(features, dict):
+            features = features.feature_maps
 
         out = []
         for feature_map in features:
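The two `DetrConvEncoder` hunks above drop the timm-specific construction path: the backbone is now always resolved through `load_backbone` (imported from `...backbone_utils` in the header hunk), a timm wrapper is unwrapped via its `_backbone` attribute so checkpoint weight names keep matching, and `forward` normalizes dict-style backbone outputs to a plain list of feature maps. Below is a minimal standalone sketch of that unwrapping pattern; `DummyBackbone` and its attributes are illustrative stand-ins, not transformers API:

```python
import torch.nn as nn


class DummyBackbone(nn.Module):
    """Illustrative stand-in for an AutoBackbone-style model that wraps a timm module."""

    def __init__(self):
        super().__init__()
        self.channels = [256, 512, 1024, 2048]  # per-stage feature widths
        self._backbone = nn.Conv2d(3, 256, kernel_size=1)  # timm-style inner module


def resolve_backbone(backbone: nn.Module) -> tuple[nn.Module, bool]:
    # Mirror of the unwrapping shown in the hunk: reach into `_backbone` when a
    # timm wrapper is detected, so the saved weight names stay stable.
    if hasattr(backbone, "_backbone"):
        return backbone._backbone, True
    return backbone, False


model, is_timm_model = resolve_backbone(DummyBackbone())
print(type(model).__name__, is_timm_model)  # Conv2d True
```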
@@ -317,61 +296,55 @@ class DetrConvEncoder(nn.Module):
         return out
 
 
-class DetrConvModel(nn.Module):
-    """
-    This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
-    """
-
-    def __init__(self, conv_encoder, position_embedding):
-        super().__init__()
-        self.conv_encoder = conv_encoder
-        self.position_embedding = position_embedding
-
-    def forward(self, pixel_values, pixel_mask):
-        # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
-        out = self.conv_encoder(pixel_values, pixel_mask)
-        pos = []
-        for feature_map, mask in out:
-            # position encoding
-            pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
-
-        return out, pos
-
-
 class DetrSinePositionEmbedding(nn.Module):
     """
     This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
     need paper, generalized to work on images.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        num_position_features: int = 64,
+        temperature: int = 10000,
+        normalize: bool = False,
+        scale: float | None = None,
+    ):
         super().__init__()
-        self.embedding_dim = embedding_dim
-        self.temperature = temperature
-        self.normalize = normalize
         if scale is not None and normalize is False:
             raise ValueError("normalize should be True if scale is passed")
-
-
-        self.
+        self.num_position_features = num_position_features
+        self.temperature = temperature
+        self.normalize = normalize
+        self.scale = 2 * math.pi if scale is None else scale
 
-
-
-
-
-
+    @compile_compatible_method_lru_cache(maxsize=1)
+    def forward(
+        self,
+        shape: torch.Size,
+        device: torch.device | str,
+        dtype: torch.dtype,
+        mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if mask is None:
+            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+        y_embed = mask.cumsum(1, dtype=dtype)
+        x_embed = mask.cumsum(2, dtype=dtype)
         if self.normalize:
-
-
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
 
-        dim_t = torch.arange(self.
-        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.
+        dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
+        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)
 
         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+        # expected by the encoder
+        pos = pos.flatten(2).permute(0, 2, 1)
         return pos
 
 
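After the hunk above, `DetrSinePositionEmbedding.forward` is keyed on `(shape, device, dtype, mask)` instead of the feature map itself and is wrapped in `compile_compatible_method_lru_cache`, so repeated calls with the same feature-map shape reuse one cached `(batch_size, sequence_length, hidden_size)` tensor. A minimal sketch of the same caching idea outside transformers; the plain `functools.lru_cache` stands in for `compile_compatible_method_lru_cache`, and unmasked `arange` coordinates replace the mask `cumsum` of the real module:

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=1)
def sine_position_embedding(shape: torch.Size, device: str, num_position_features: int = 64) -> torch.Tensor:
    """Cached sine/cosine embeddings for a (batch, channels, height, width) shape."""
    batch_size, _, height, width = shape
    # Unmasked coordinates: equivalent to cumsum over an all-ones mask.
    y_embed = torch.arange(1, height + 1, device=device).view(1, height, 1).expand(batch_size, height, width)
    x_embed = torch.arange(1, width + 1, device=device).view(1, 1, width).expand(batch_size, height, width)
    dim_t = torch.arange(num_position_features, device=device)
    dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_position_features)
    pos_x = x_embed[..., None] / dim_t
    pos_y = y_embed[..., None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
    pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
    # Flatten to (batch_size, sequence_length, hidden_size), as the refactored module now does.
    return pos.flatten(2).permute(0, 2, 1)


emb = sine_position_embedding(torch.Size((2, 256, 32, 32)), "cpu")
print(emb.shape)  # torch.Size([2, 1024, 128])
```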
@@ -385,207 +358,260 @@ class DetrLearnedPositionEmbedding(nn.Module):
         self.row_embeddings = nn.Embedding(50, embedding_dim)
         self.column_embeddings = nn.Embedding(50, embedding_dim)

-
-
-
-
+    @compile_compatible_method_lru_cache(maxsize=1)
+    def forward(
+        self,
+        shape: torch.Size,
+        device: torch.device | str,
+        dtype: torch.dtype,
+        mask: torch.Tensor | None = None,
+    ):
+        height, width = shape[-2:]
+        width_values = torch.arange(width, device=device)
+        height_values = torch.arange(height, device=device)
         x_emb = self.column_embeddings(width_values)
         y_emb = self.row_embeddings(height_values)
         pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
         pos = pos.permute(2, 0, 1)
         pos = pos.unsqueeze(0)
-        pos = pos.repeat(
+        pos = pos.repeat(shape[0], 1, 1, 1)
+        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+        # expected by the encoder
+        pos = pos.flatten(2).permute(0, 2, 1)
         return pos


-
-
-
-
-
-
-
-
-
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor | None,
+    scaling: float | None = None,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

-
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
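The `eager_attention_forward` helper added above takes `(batch, num_heads, seq_len, head_dim)` tensors and an additive attention mask, and returns the attention output transposed back to `(batch, seq_len, num_heads, head_dim)`. A short usage sketch mirroring the same math in plain PyTorch (the stand-in helper and the `nn.Identity` module, used only for its `training` flag, are illustrative assumptions):

```python
import torch
from torch import nn

def eager_attention(module, query, key, value, attention_mask=None, scaling=None, dropout=0.0):
    # Same math as the helper in this diff: scaled dot-product, additive mask, softmax, dropout.
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights

batch, heads, seq, head_dim = 2, 8, 10, 32
q = k = v = torch.randn(batch, heads, seq, head_dim)
module = nn.Identity()  # any nn.Module works; only .training is read
out, weights = eager_attention(module, q, k, v)
print(out.shape, weights.shape)  # torch.Size([2, 10, 8, 32]) torch.Size([2, 8, 10, 10])
```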
+class DetrSelfAttention(nn.Module):
     """
-    Multi-headed attention from 'Attention Is All You Need' paper.
+    Multi-headed self-attention from 'Attention Is All You Need' paper.

-
+    In DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
     """

     def __init__(
         self,
-
-
+        config: DetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
         dropout: float = 0.0,
         bias: bool = True,
     ):
         super().__init__()
-        self.
-        self.
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        if self.head_dim * num_heads != self.embed_dim:
-            raise ValueError(
-                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {num_heads})."
-            )
+        self.config = config
+        self.head_dim = hidden_size // num_attention_heads
         self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+        self.is_causal = False

-        self.k_proj = nn.Linear(
-        self.v_proj = nn.Linear(
-        self.q_proj = nn.Linear(
-        self.
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
-        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]):
-        return tensor if object_queries is None else tensor + object_queries
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask:
-
-
-
-
-
-        """
-
-
-        is_cross_attention = key_value_states is not None
-        batch_size, target_len, embed_dim = hidden_states.size()
-
-        # add position embeddings to the hidden states before projecting to queries and keys
-        if object_queries is not None:
-            hidden_states_original = hidden_states
-            hidden_states = self.with_pos_embed(hidden_states, object_queries)
-
-        # add key-value position embeddings to the key value states
-        if spatial_position_embeddings is not None:
-            key_value_states_original = key_value_states
-            key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
-
-        # get query proj
-        query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        if is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
-            value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
-            value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
+        attention_mask: torch.Tensor | None = None,
+        position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Position embeddings are added to both queries and keys (but not values).
+        """
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)

-
-        query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states

-
+        query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

-
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )

-
-
-
-
-
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )

-
-
-
-            f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
-            f" {attention_mask.size()}"
-        )
-        if attention_mask.dtype == torch.bool:
-            attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
-                attention_mask, -torch.inf
-            )
-        attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
-        attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
-            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
-        else:
-            attn_weights_reshaped = None
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights

-        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

-
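The DETR-specific detail in `DetrSelfAttention` above is that the optional position embeddings are added to the inputs of the query and key projections, but never to the value projection. A compressed standalone sketch of that pattern (plain PyTorch; the sizes and variable names are illustrative, not the module's internals):

```python
import torch
from torch import nn

hidden_size, num_heads = 256, 8
head_dim = hidden_size // num_heads
q_proj = nn.Linear(hidden_size, hidden_size)
k_proj = nn.Linear(hidden_size, hidden_size)
v_proj = nn.Linear(hidden_size, hidden_size)

hidden_states = torch.randn(2, 100, hidden_size)         # (batch, seq, hidden)
position_embeddings = torch.randn(2, 100, hidden_size)   # same shape as hidden_states

# Positions influence where to attend (queries/keys) but not what is aggregated (values).
query_key_input = hidden_states + position_embeddings
shape = (2, 100, num_heads, head_dim)
query = q_proj(query_key_input).view(shape).transpose(1, 2)   # (batch, heads, seq, head_dim)
key = k_proj(query_key_input).view(shape).transpose(1, 2)
value = v_proj(hidden_states).view(shape).transpose(1, 2)     # note: no position embeddings here
print(query.shape, value.shape)
```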
+class DetrCrossAttention(nn.Module):
+    """
+    Multi-headed cross-attention from 'Attention Is All You Need' paper.

-
-
-
-            f" {attn_output.size()}"
-        )
+    In DETR, queries get their own position embeddings, while keys get encoder position embeddings.
+    Values don't get any position embeddings.
+    """

-
-
-
+    def __init__(
+        self,
+        config: DetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.config = config
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+        self.is_causal = False

-
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

-
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        position_embeddings: torch.Tensor | None = None,
+        encoder_position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Position embeddings logic:
+        - Queries get position_embeddings
+        - Keys get encoder_position_embeddings
+        - Values don't get any position embeddings
+        """
+        query_input_shape = hidden_states.shape[:-1]
+        query_hidden_shape = (*query_input_shape, -1, self.head_dim)
+
+        kv_input_shape = key_value_states.shape[:-1]
+        kv_hidden_shape = (*kv_input_shape, -1, self.head_dim)
+
+        query_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
+        key_input = (
+            key_value_states + encoder_position_embeddings
+            if encoder_position_embeddings is not None
+            else key_value_states
+        )

+        query_states = self.q_proj(query_input).view(query_hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(key_input).view(kv_hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(key_value_states).view(kv_hidden_shape).transpose(1, 2)

-
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class DetrMLP(nn.Module):
+    def __init__(self, config: DetrConfig, hidden_size: int, intermediate_size: int):
+        super().__init__()
+        self.fc1 = nn.Linear(hidden_size, intermediate_size)
+        self.fc2 = nn.Linear(intermediate_size, hidden_size)
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.dropout = config.dropout
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        return hidden_states
+
+
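In the new `DetrCrossAttention`, the decoder side and encoder side receive different position information: queries get the object-query position embeddings, keys get the encoder's spatial position embeddings, and values get none. A standalone sketch of that asymmetric addition (plain PyTorch; shapes and names are illustrative):

```python
import torch
from torch import nn

hidden_size = 256
q_proj, k_proj, v_proj = (nn.Linear(hidden_size, hidden_size) for _ in range(3))

object_queries = torch.randn(2, 100, hidden_size)    # decoder-side hidden states
query_pos = torch.randn(2, 100, hidden_size)         # object query position embeddings
encoder_states = torch.randn(2, 950, hidden_size)    # flattened image features from the encoder
spatial_pos = torch.randn(2, 950, hidden_size)       # 2D position embeddings of image locations

query = q_proj(object_queries + query_pos)           # queries: plus object-query positions
key = k_proj(encoder_states + spatial_pos)           # keys: plus spatial positions
value = v_proj(encoder_states)                       # values: no position information
print(query.shape, key.shape, value.shape)
```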
+class DetrEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DetrConfig):
         super().__init__()
-        self.
-        self.self_attn =
-
-
+        self.hidden_size = config.d_model
+        self.self_attn = DetrSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
         )
-        self.self_attn_layer_norm = nn.LayerNorm(self.
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
         self.dropout = config.dropout
-        self.
-        self.
-        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
-        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.mlp = DetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-
-
-    ):
+        spatial_position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len,
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-
-
-
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
+            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
+                Spatial position embeddings (2D positional encodings of image locations), to be added to both
+                the queries and keys in self-attention (but not to values).
         """
         residual = hidden_states
-        hidden_states,
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-
-
+            position_embeddings=spatial_position_embeddings,
+            **kwargs,
         )

         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -593,12 +619,7 @@ class DetrEncoderLayer(nn.Module):
         hidden_states = self.self_attn_layer_norm(hidden_states)

         residual = hidden_states
-        hidden_states = self.
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
+        hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)

@@ -607,78 +628,69 @@ class DetrEncoderLayer(nn.Module):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

-
-
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs
+        return hidden_states


 class DetrDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DetrConfig):
         super().__init__()
-        self.
+        self.hidden_size = config.d_model

-        self.self_attn =
-
-
+        self.self_attn = DetrSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.decoder_attention_heads,
             dropout=config.attention_dropout,
         )
         self.dropout = config.dropout
-        self.activation_fn = ACT2FN[config.activation_function]
-        self.activation_dropout = config.activation_dropout

-        self.self_attn_layer_norm = nn.LayerNorm(self.
-        self.encoder_attn =
-
-
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+        self.encoder_attn = DetrCrossAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.decoder_attention_heads,
             dropout=config.attention_dropout,
         )
-        self.encoder_attn_layer_norm = nn.LayerNorm(self.
-        self.
-        self.
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+        self.mlp = DetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask:
-
-
-        encoder_hidden_states:
-        encoder_attention_mask:
-
-    ):
+        attention_mask: torch.Tensor | None = None,
+        spatial_position_embeddings: torch.Tensor | None = None,
+        object_queries_position_embeddings: torch.Tensor | None = None,
+        encoder_hidden_states: torch.Tensor | None = None,
+        encoder_attention_mask: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len,
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-
-
-
-
-
-
+            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
+                Spatial position embeddings (2D positional encodings from encoder) that are added to the keys only
+                in the cross-attention layer (not to values).
+            object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
+                Position embeddings for the object query slots. In self-attention, these are added to both queries
+                and keys (not values). In cross-attention, these are added to queries only (not to keys or values).
             encoder_hidden_states (`torch.FloatTensor`):
-                cross attention input to the layer of shape `(batch, seq_len,
+                cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
             encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
         """
         residual = hidden_states

         # Self Attention
-        hidden_states,
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
-
+            position_embeddings=object_queries_position_embeddings,
             attention_mask=attention_mask,
-
+            **kwargs,
         )

         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -686,17 +698,16 @@ class DetrDecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.self_attn_layer_norm(hidden_states)

         # Cross-Attention Block
-        cross_attn_weights = None
         if encoder_hidden_states is not None:
             residual = hidden_states

-            hidden_states,
+            hidden_states, _ = self.encoder_attn(
                 hidden_states=hidden_states,
-                object_queries=query_position_embeddings,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
-
-
+                position_embeddings=object_queries_position_embeddings,
+                encoder_position_embeddings=spatial_position_embeddings,
+                **kwargs,
             )

             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -705,19 +716,164 @@ class DetrDecoderLayer(GradientCheckpointingLayer):

         # Fully Connected
         residual = hidden_states
-        hidden_states = self.
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)

-
+        return hidden_states
+
+
+class DetrConvBlock(nn.Module):
+    """Basic conv block: Conv3x3 -> GroupNorm -> Activation."""
+
+    def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+        self.norm = nn.GroupNorm(min(8, out_channels), out_channels)
+        self.activation = ACT2FN[activation]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.activation(self.norm(self.conv(x)))
+
+
+class DetrFPNFusionStage(nn.Module):
+    """Single FPN fusion stage combining low-resolution features with high-resolution FPN features."""
+
+    def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"):
+        super().__init__()
+        self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
+        self.refine = DetrConvBlock(current_channels, output_channels, activation)
+
+    def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in)
+            fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out)
+
+        Returns:
+            Fused and refined features, shape (B*Q, output_channels, H_out, W_out)
+        """
+        fpn_features = self.fpn_adapter(fpn_features)
+        features = nn.functional.interpolate(features, size=fpn_features.shape[-2:], mode="nearest")
+        return self.refine(fpn_features + features)
+

-
-
+class DetrMaskHeadSmallConv(nn.Module):
+    """
+    Segmentation mask head that generates per-query masks using FPN-based progressive upsampling.

-
+    Combines attention maps (spatial localization) with encoder features (semantics) and progressively
+    upsamples through multiple scales, fusing with FPN features for high-resolution detail.
+    """
+
+    def __init__(
+        self,
+        input_channels: int,
+        fpn_channels: list[int],
+        hidden_size: int,
+        activation_function: str = "relu",
+    ):
+        super().__init__()
+        if input_channels % 8 != 0:
+            raise ValueError(f"input_channels must be divisible by 8, got {input_channels}")
+
+        self.conv1 = DetrConvBlock(input_channels, input_channels, activation_function)
+        self.conv2 = DetrConvBlock(input_channels, hidden_size // 2, activation_function)
+
+        # Progressive channel reduction: /2 -> /4 -> /8 -> /16
+        self.fpn_stages = nn.ModuleList(
+            [
+                DetrFPNFusionStage(fpn_channels[0], hidden_size // 2, hidden_size // 4, activation_function),
+                DetrFPNFusionStage(fpn_channels[1], hidden_size // 4, hidden_size // 8, activation_function),
+                DetrFPNFusionStage(fpn_channels[2], hidden_size // 8, hidden_size // 16, activation_function),
+            ]
+        )
+
+        self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1)
+
+    def forward(
+        self,
+        features: torch.Tensor,
+        attention_masks: torch.Tensor,
+        fpn_features: list[torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Args:
+            features: Encoder output features, shape (batch_size, hidden_size, H, W)
+            attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W)
+            fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, W)
+
+        Returns:
+            Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W)
+        """
+        num_queries = attention_masks.shape[1]
+
+        # Expand to (batch_size * num_queries) dimension
+        features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)
+        attention_masks = attention_masks.flatten(0, 1)
+        fpn_features = [
+            fpn_feat.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) for fpn_feat in fpn_features
+        ]
+
+        hidden_states = torch.cat([features, attention_masks], dim=1)
+        hidden_states = self.conv1(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        for fpn_stage, fpn_feat in zip(self.fpn_stages, fpn_features):
+            hidden_states = fpn_stage(hidden_states, fpn_feat)
+
+        return self.output_conv(hidden_states)
+
+
+class DetrMHAttentionMap(nn.Module):
+    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+    def forward(
+        self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None
+    ):
+        query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim)
+        key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:])
+
+        query_states = self.q_proj(query_states).view(query_hidden_shape)
+        key_states = nn.functional.conv2d(
+            key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias
+        ).view(key_hidden_shape)
+
+        batch_size, num_queries, num_heads, head_dim = query_states.shape
+        _, _, _, height, width = key_states.shape
+        query_shape = (batch_size * num_heads, num_queries, head_dim)
+        key_shape = (batch_size * num_heads, height * width, head_dim)
+        attn_weights_shape = (batch_size, num_heads, num_queries, height, width)
+
+        query = query_states.transpose(1, 2).contiguous().view(query_shape)
+        key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape)
+
+        attn_weights = (
+            (torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2)
+        )
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size())
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+        return attn_weights


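The mask head above reduces channels progressively (hidden_size/2 -> /4 -> /8 -> /16) while each fusion stage upsamples the coarse features to the next FPN resolution with nearest-neighbor interpolation, adapts the FPN features with a 1x1 convolution, and refines the sum. A standalone sketch of one such fusion step (plain PyTorch; the channel counts and spatial sizes are illustrative):

```python
import torch
from torch import nn

# One illustrative fusion step: adapt FPN channels, upsample coarse features, add, refine.
fpn_channels, current_channels, output_channels = 512, 128, 64
fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
refine = nn.Sequential(
    nn.Conv2d(current_channels, output_channels, kernel_size=3, padding=1),
    nn.GroupNorm(8, output_channels),
    nn.ReLU(),
)

coarse = torch.randn(4, current_channels, 16, 16)   # low-resolution features (e.g. batch*queries flattened)
fpn_feat = torch.randn(4, fpn_channels, 32, 32)     # higher-resolution FPN features

fpn_feat = fpn_adapter(fpn_feat)                                                  # (4, 128, 32, 32)
coarse = nn.functional.interpolate(coarse, size=fpn_feat.shape[-2:], mode="nearest")
fused = refine(fpn_feat + coarse)                                                 # (4, 64, 32, 32)
print(fused.shape)
```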
 @auto_docstring
@@ -727,21 +883,36 @@ class DetrPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     input_modalities = ("image",)
     _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
+    supports_gradient_checkpointing = True
+    _supports_sdpa = True
+    _supports_flash_attn = True
+    _supports_attention_backend = True
+    _supports_flex_attn = True  # Uses create_bidirectional_masks for attention masking
+    _keys_to_ignore_on_load_unexpected = [
+        r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
+    ]

     @torch.no_grad()
     def _init_weights(self, module):
         std = self.config.init_std
         xavier_std = self.config.init_xavier_std

-        if isinstance(module,
-
-
-
-
+        if isinstance(module, DetrMaskHeadSmallConv):
+            # DetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers
+            for m in module.modules():
+                if isinstance(m, nn.Conv2d):
+                    init.kaiming_uniform_(m.weight, a=1)
+                    if m.bias is not None:
+                        init.constant_(m.bias, 0)
+        elif isinstance(module, DetrMHAttentionMap):
+            init.zeros_(module.k_proj.bias)
+            init.zeros_(module.q_proj.bias)
+            init.xavier_uniform_(module.k_proj.weight, gain=xavier_std)
+            init.xavier_uniform_(module.q_proj.weight, gain=xavier_std)
         elif isinstance(module, DetrLearnedPositionEmbedding):
             init.uniform_(module.row_embeddings.weight)
             init.uniform_(module.column_embeddings.weight)
-
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
             init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
                 init.zeros_(module.bias)
@@ -757,47 +928,36 @@ class DetrPreTrainedModel(PreTrainedModel):

 class DetrEncoder(DetrPreTrainedModel):
     """
-    Transformer encoder
-    [`DetrEncoderLayer`].
-
-    The encoder updates the flattened feature map through multiple self-attention layers.
-
-    Small tweak for DETR:
-
-    - object_queries are added to the forward pass.
+    Transformer encoder that processes a flattened feature map from a vision backbone, composed of a stack of
+    [`DetrEncoderLayer`] modules.

     Args:
-        config:
+        config (`DetrConfig`): Model configuration object.
     """

+    _can_record_outputs = {"hidden_states": DetrEncoderLayer, "attentions": DetrSelfAttention}
+
     def __init__(self, config: DetrConfig):
         super().__init__(config)

         self.dropout = config.dropout
-        self.layerdrop = config.encoder_layerdrop
-
         self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)])

-        # in the original DETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default
-
         # Initialize weights and apply final processing
         self.post_init()

+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         attention_mask=None,
-
-
-
-        return_dict=None,
-        **kwargs,
-    ):
+        spatial_position_embeddings=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
-
             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:

@@ -805,112 +965,67 @@ class DetrEncoder(DetrPreTrainedModel):
                 - 0 for pixel features that are padding (i.e. **masked**).

                 [What are attention masks?](../glossary#attention-mask)
-
-
-                Object queries that are added to the queries in each self-attention layer.
-
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         hidden_states = inputs_embeds
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

         # expand attention_mask
         if attention_mask is not None:
             # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-            attention_mask =
-
-
-
-
-        if output_hidden_states:
-            encoder_states = encoder_states + (hidden_states,)
-            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
-            to_drop = False
-            if self.training:
-                dropout_probability = torch.rand([])
-                if dropout_probability < self.layerdrop:  # skip the layer
-                    to_drop = True
-
-            if to_drop:
-                layer_outputs = (None, None)
-            else:
-                # we add object_queries as extra input to the encoder_layer
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    object_queries=object_queries,
-                    output_attentions=output_attentions,
-                )
-
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
-
-        if output_hidden_states:
-            encoder_states = encoder_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
-        )
-
+            attention_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+            )

-
-
-
+        for encoder_layer in self.layers:
+            # we add spatial_position_embeddings as extra input to the encoder_layer
+            hidden_states = encoder_layer(
+                hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs
+            )

-
+        return BaseModelOutput(last_hidden_state=hidden_states)

-        Some small tweaks for DETR:

-
-
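`create_bidirectional_mask` (called above in the encoder and below in the decoder) replaces the hand-rolled mask expansion of the old code. Since the eager attention path adds the mask directly to the attention scores, a `(batch_size, seq_len)` padding mask effectively ends up as an additive bias broadcastable to `(batch_size, 1, query_len, key_len)` with large negative values at padded key positions. A standalone sketch of that effect only, not of the helper's actual implementation (the function name here is illustrative):

```python
import torch

def expand_padding_mask(padding_mask: torch.Tensor, query_len: int, dtype=torch.float32) -> torch.Tensor:
    """Turn a (batch, key_len) mask with 1 = keep / 0 = pad into an additive (batch, 1, query_len, key_len) bias."""
    batch_size, key_len = padding_mask.shape
    bias = torch.zeros(batch_size, 1, query_len, key_len, dtype=dtype)
    # zeros for valid keys, a very large negative value for padded keys
    bias.masked_fill_(padding_mask[:, None, None, :] == 0, torch.finfo(dtype).min)
    return bias

padding_mask = torch.tensor([[1, 1, 1, 0, 0]])   # last two key positions are padding
bias = expand_padding_mask(padding_mask, query_len=3)
print(bias.shape)  # torch.Size([1, 1, 3, 5])
```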
+class DetrDecoder(DetrPreTrainedModel):
+    """
+    Transformer decoder that refines a set of object queries. It is composed of a stack of [`DetrDecoderLayer`] modules,
+    which apply self-attention to the queries and cross-attention to the encoder's outputs.

     Args:
-        config:
+        config (`DetrConfig`): Model configuration object.
     """

+    _can_record_outputs = {
+        "hidden_states": DetrDecoderLayer,
+        "attentions": DetrSelfAttention,
+        "cross_attentions": DetrCrossAttention,
+    }
+
     def __init__(self, config: DetrConfig):
         super().__init__(config)
         self.dropout = config.dropout
-        self.layerdrop = config.decoder_layerdrop

         self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
         # in DETR, the decoder uses layernorm after the last decoder layer output
         self.layernorm = nn.LayerNorm(config.d_model)

-        self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()

+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
-
-
-
-
-        return_dict=None,
-        **kwargs,
-    ):
+        spatial_position_embeddings=None,
+        object_queries_position_embeddings=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> DetrDecoderOutput:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -933,108 +1048,62 @@ class DetrDecoder(DetrPreTrainedModel):
                 - 1 for pixels that are real (i.e. **not masked**),
                 - 0 for pixels that are padding (i.e. **masked**).

-
-
-
-
-
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Spatial position embeddings (2D positional encodings from encoder) that are added to the keys in each cross-attention layer.
+            object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+                Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if inputs_embeds is not None:
             hidden_states = inputs_embeds
-            input_shape = inputs_embeds.size()[:-1]
-
-        combined_attention_mask = None

-
-
-
-
+        # expand decoder attention mask (for self-attention on object queries)
+        if attention_mask is not None:
+            # [batch_size, num_queries] -> [batch_size, 1, num_queries, num_queries]
+            attention_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=attention_mask,
             )

-        # expand encoder attention mask
+        # expand encoder attention mask (for cross-attention on encoder outputs)
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-            encoder_attention_mask =
-
+            encoder_attention_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=encoder_attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
             )

         # optional intermediate hidden states
         intermediate = () if self.config.auxiliary_loss else None

         # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

         for idx, decoder_layer in enumerate(self.layers):
-
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            if self.training:
-                dropout_probability = torch.rand([])
-                if dropout_probability < self.layerdrop:
-                    continue
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
-
-
-
+                attention_mask,
+                spatial_position_embeddings,
+                object_queries_position_embeddings,
                 encoder_hidden_states,  # as a positional argument for gradient checkpointing
                 encoder_attention_mask=encoder_attention_mask,
-
+                **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
             if self.config.auxiliary_loss:
                 hidden_states = self.layernorm(hidden_states)
                 intermediate += (hidden_states,)

-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-                if encoder_hidden_states is not None:
-                    all_cross_attentions += (layer_outputs[2],)
-
         # finally, apply layernorm
         hidden_states = self.layernorm(hidden_states)

-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         # stack intermediate decoder activations
         if self.config.auxiliary_loss:
             intermediate = torch.stack(intermediate)

-
-            return tuple(
-                v
-                for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate]
-                if v is not None
-            )
-        return DetrDecoderOutput(
-            last_hidden_state=hidden_states,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-            cross_attentions=all_cross_attentions,
-            intermediate_hidden_states=intermediate,
-        )
+        return DetrDecoderOutput(last_hidden_state=hidden_states, intermediate_hidden_states=intermediate)


 @auto_docstring(
@@ -1047,15 +1116,16 @@ class DetrModel(DetrPreTrainedModel):
     def __init__(self, config: DetrConfig):
         super().__init__(config)

-
-        backbone = DetrConvEncoder(config)
-        object_queries = build_position_encoding(config)
-        self.backbone = DetrConvModel(backbone, object_queries)
-
-        # Create projection layer
-        self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
+        self.backbone = DetrConvEncoder(config)

+        if config.position_embedding_type == "sine":
+            self.position_embedding = DetrSinePositionEmbedding(config.d_model // 2, normalize=True)
+        elif config.position_embedding_type == "learned":
+            self.position_embedding = DetrLearnedPositionEmbedding(config.d_model // 2)
+        else:
+            raise ValueError(f"Not supported {config.position_embedding_type}")
         self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
+        self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)

         self.encoder = DetrEncoder(config)
         self.decoder = DetrDecoder(config)
@@ -1064,46 +1134,49 @@ class DetrModel(DetrPreTrainedModel):
         self.post_init()

     def freeze_backbone(self):
-        for
+        for _, param in self.backbone.model.named_parameters():
             param.requires_grad_(False)

     def unfreeze_backbone(self):
-        for
+        for _, param in self.backbone.model.named_parameters():
             param.requires_grad_(True)

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
-        pixel_values: torch.FloatTensor,
-        pixel_mask:
-        decoder_attention_mask:
-        encoder_outputs:
-        inputs_embeds:
-        decoder_inputs_embeds:
-
-
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.FloatTensor], DetrModelOutput]:
+        pixel_values: torch.FloatTensor | None = None,
+        pixel_mask: torch.LongTensor | None = None,
+        decoder_attention_mask: torch.FloatTensor | None = None,
+        encoder_outputs: torch.FloatTensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        decoder_inputs_embeds: torch.FloatTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.FloatTensor] | DetrModelOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
-
+            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for queries that are **not masked**,
+            - 0 for queries that are **masked**.
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
-            can choose to directly pass a flattened representation of an image.
+            can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
         decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
             Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
+            embedded representation. Useful for tasks that require custom query initialization.

         Examples:

         ```python
         >>> from transformers import AutoImageProcessor, DetrModel
         >>> from PIL import Image
-        >>> import
+        >>> import httpx
+        >>> from io import BytesIO

         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>>
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))

         >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
         >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")
@@ -1120,79 +1193,77 @@ class DetrModel(DetrPreTrainedModel):
         >>> list(last_hidden_states.shape)
         [1, 100, 256]
         ```"""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if pixel_values is None and inputs_embeds is None:
+            raise ValueError("You have to specify either pixel_values or inputs_embeds")
+
+        if inputs_embeds is None:
+            batch_size, num_channels, height, width = pixel_values.shape
+            device = pixel_values.device
+
+            if pixel_mask is None:
+                pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+            vision_features = self.backbone(pixel_values, pixel_mask)
+            feature_map, mask = vision_features[-1]
+
+            # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
+            # Position embeddings are already flattened to (batch_size, sequence_length, hidden_size) format
+            projected_feature_map = self.input_projection(feature_map)
+            flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
+            spatial_position_embeddings = self.position_embedding(
+                shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
+            )
+            flattened_mask = mask.flatten(1)
+        else:
+            batch_size = inputs_embeds.shape[0]
+            device = inputs_embeds.device
+            flattened_features = inputs_embeds
+            # When using inputs_embeds, we need to infer spatial dimensions for position embeddings
+            # Assume square feature map
+            seq_len = inputs_embeds.shape[1]
+            feat_dim = int(seq_len**0.5)
+            # Create position embeddings for the inferred spatial size
+            spatial_position_embeddings = self.position_embedding(
+                shape=torch.Size([batch_size, self.config.d_model, feat_dim, feat_dim]),
+                device=device,
+                dtype=inputs_embeds.dtype,
+            )
+            # If a pixel_mask is provided with inputs_embeds, interpolate it to feat_dim, then flatten.
+            if pixel_mask is not None:
+                mask = nn.functional.interpolate(pixel_mask[None].float(), size=(feat_dim, feat_dim)).to(torch.bool)[0]
+                flattened_mask = mask.flatten(1)
+            else:
+                # If no mask provided, assume all positions are valid
+                flattened_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long)

-        # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
-        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
-        # flattened_mask is a Tensor of shape (batch_size, height*width)
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 inputs_embeds=flattened_features,
                 attention_mask=flattened_mask,
-
-
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+                spatial_position_embeddings=spatial_position_embeddings,
+                **kwargs,
             )

-
-
-
+        object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(
+            batch_size, 1, 1
+        )
+
+        # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
+        if decoder_inputs_embeds is not None:
+            queries = decoder_inputs_embeds
+        else:
+            queries = torch.zeros_like(object_queries_position_embeddings)

         # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
         decoder_outputs = self.decoder(
             inputs_embeds=queries,
-            attention_mask=
-
-
-            encoder_hidden_states=encoder_outputs
+            attention_mask=decoder_attention_mask,
+            spatial_position_embeddings=spatial_position_embeddings,
+            object_queries_position_embeddings=object_queries_position_embeddings,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
             encoder_attention_mask=flattened_mask,
-
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

-        if not return_dict:
-            return decoder_outputs + encoder_outputs
-
         return DetrModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
             decoder_hidden_states=decoder_outputs.hidden_states,
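The projection-and-flattening step in `DetrModel.forward` is what turns a backbone feature map into the `(batch_size, sequence_length, hidden_size)` sequence the encoder consumes. A standalone sketch of the shape bookkeeping (plain PyTorch; `d_model=256`, the backbone channel count, and the feature-map size are illustrative assumptions):

```python
import torch
from torch import nn

d_model = 256
backbone_channels = 2048                      # e.g. the last stage of a ResNet-50 backbone
input_projection = nn.Conv2d(backbone_channels, d_model, kernel_size=1)

feature_map = torch.randn(2, backbone_channels, 25, 38)       # (batch, C, H, W) from the backbone
projected = input_projection(feature_map)                     # (2, 256, 25, 38)
flattened = projected.flatten(2).permute(0, 2, 1)             # (2, 25*38, 256) = (batch, seq_len, hidden)
print(flattened.shape)                                        # torch.Size([2, 950, 256])

# The inverse reshape, useful when a head needs the spatial layout back:
restored = flattened.permute(0, 2, 1).view(2, d_model, 25, 38)
print(restored.shape)                                         # torch.Size([2, 256, 25, 38])
```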
@@ -1205,14 +1276,11 @@ class DetrModel(DetrPreTrainedModel):
        )


-# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
class DetrMLPPredictionHead(nn.Module):
    """
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):

@@ -1252,29 +1320,30 @@ class DetrForObjectDetection(DetrPreTrainedModel):
        self.post_init()

    @auto_docstring
+    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor,
-        pixel_mask:
-        decoder_attention_mask:
-        encoder_outputs:
-        inputs_embeds:
-        decoder_inputs_embeds:
-        labels:
-
-
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.FloatTensor], DetrObjectDetectionOutput]:
+        pixel_mask: torch.LongTensor | None = None,
+        decoder_attention_mask: torch.FloatTensor | None = None,
+        encoder_outputs: torch.FloatTensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        decoder_inputs_embeds: torch.FloatTensor | None = None,
+        labels: list[dict] | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.FloatTensor] | DetrObjectDetectionOutput:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
-
+            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for queries that are **not masked**,
+            - 0 for queries that are **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
-            can choose to directly pass a flattened representation of an image.
+            can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
+            embedded representation. Useful for tasks that require custom query initialization.
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
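The rewritten signature drops the explicit `output_attentions` / `output_hidden_states` / `return_dict` parameters in favour of `**kwargs: Unpack[TransformersKwargs]` plus the `@can_return_tuple` decorator. A minimal sketch of the `Unpack`-typed kwargs pattern itself, using a made-up `ExampleKwargs` TypedDict rather than the library's actual `TransformersKwargs`:

```python
from typing import TypedDict

from typing_extensions import Unpack


class ExampleKwargs(TypedDict, total=False):
    # Hypothetical flags, for illustration only.
    output_attentions: bool
    output_hidden_states: bool


def forward(pixel_values, **kwargs: Unpack[ExampleKwargs]):
    # A type checker can now validate keywords flowing through **kwargs, so the
    # flags no longer need to be threaded through every signature by hand.
    return kwargs.get("output_hidden_states", False)


print(forward("dummy", output_hidden_states=True))  # True
```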
@@ -1287,10 +1356,12 @@ class DetrForObjectDetection(DetrPreTrainedModel):
        >>> from transformers import AutoImageProcessor, DetrForObjectDetection
        >>> import torch
        >>> from PIL import Image
-        >>> import
+        >>> import httpx
+        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>>
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

@@ -1316,7 +1387,6 @@ class DetrForObjectDetection(DetrPreTrainedModel):
        Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
        Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
        ```"""
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # First, sent images through DETR base model to obtain encoder + decoder outputs
        outputs = self.model(

@@ -1326,9 +1396,7 @@ class DetrForObjectDetection(DetrPreTrainedModel):
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
-
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
        )

        sequence_output = outputs[0]

@@ -1341,20 +1409,13 @@ class DetrForObjectDetection(DetrPreTrainedModel):
        if labels is not None:
            outputs_class, outputs_coord = None, None
            if self.config.auxiliary_loss:
-                intermediate = outputs.intermediate_hidden_states
+                intermediate = outputs.intermediate_hidden_states
                outputs_class = self.class_labels_classifier(intermediate)
                outputs_coord = self.bbox_predictor(intermediate).sigmoid()
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
            )

-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes) + auxiliary_outputs + outputs
-            else:
-                output = (logits, pred_boxes) + outputs
-            return ((loss, loss_dict) + output) if loss is not None else output
-
        return DetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
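With auxiliary loss enabled, the same classification and box heads are run over the stacked intermediate decoder states before everything is handed to `self.loss_function`. A rough sketch of that broadcast with stand-in heads; the `(batch, num_layers, num_queries, hidden_size)` layout and the head definitions are assumptions for illustration, not the model's real modules:

```python
import torch
from torch import nn

batch, num_layers, num_queries, hidden_size, num_labels = 2, 6, 100, 256, 92
intermediate = torch.randn(batch, num_layers, num_queries, hidden_size)

# Stand-in heads: a linear classifier and a small MLP box regressor applied on the last dim.
class_head = nn.Linear(hidden_size, num_labels)
bbox_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 4))

outputs_class = class_head(intermediate)            # (batch, num_layers, num_queries, num_labels)
outputs_coord = bbox_head(intermediate).sigmoid()   # (batch, num_layers, num_queries, 4)
print(outputs_class.shape, outputs_coord.shape)
```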
@@ -1378,6 +1439,26 @@ class DetrForObjectDetection(DetrPreTrainedModel):
    """
)
class DetrForSegmentation(DetrPreTrainedModel):
+    _checkpoint_conversion_mapping = {
+        "bbox_attention.q_linear": "bbox_attention.q_proj",
+        "bbox_attention.k_linear": "bbox_attention.k_proj",
+        # Mask head refactor
+        "mask_head.lay1": "mask_head.conv1.conv",
+        "mask_head.gn1": "mask_head.conv1.norm",
+        "mask_head.lay2": "mask_head.conv2.conv",
+        "mask_head.gn2": "mask_head.conv2.norm",
+        "mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter",
+        "mask_head.lay3": "mask_head.fpn_stages.0.refine.conv",
+        "mask_head.gn3": "mask_head.fpn_stages.0.refine.norm",
+        "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter",
+        "mask_head.lay4": "mask_head.fpn_stages.1.refine.conv",
+        "mask_head.gn4": "mask_head.fpn_stages.1.refine.norm",
+        "mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter",
+        "mask_head.lay5": "mask_head.fpn_stages.2.refine.conv",
+        "mask_head.gn5": "mask_head.fpn_stages.2.refine.norm",
+        "mask_head.out_lay": "mask_head.output_conv",
+    }
+
    def __init__(self, config: DetrConfig):
        super().__init__(config)

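The new `_checkpoint_conversion_mapping` lets checkpoints saved with the old parameter names (`mask_head.lay1`, `bbox_attention.q_linear`, ...) load against the refactored modules. The actual renaming is performed by the library's loading machinery; the snippet below is only an illustrative sketch of how such a prefix mapping rewrites state-dict keys, using a hypothetical `rename_keys` helper:

```python
def rename_keys(state_dict: dict, mapping: dict) -> dict:
    """Return a copy of `state_dict` with every old prefix replaced by its new name."""
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in mapping.items():
            if old_prefix in new_key:
                new_key = new_key.replace(old_prefix, new_prefix)
        renamed[new_key] = value
    return renamed


old_state_dict = {"mask_head.lay1.weight": 0, "bbox_attention.q_linear.bias": 1}
mapping = {"mask_head.lay1": "mask_head.conv1.conv", "bbox_attention.q_linear": "bbox_attention.q_proj"}
print(sorted(rename_keys(old_state_dict, mapping)))
# ['bbox_attention.q_proj.bias', 'mask_head.conv1.conv.weight']
```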
@@ -1386,42 +1467,44 @@ class DetrForSegmentation(DetrPreTrainedModel):

        # segmentation head
        hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
-        intermediate_channel_sizes = self.detr.model.backbone.
+        intermediate_channel_sizes = self.detr.model.backbone.intermediate_channel_sizes

        self.mask_head = DetrMaskHeadSmallConv(
-            hidden_size + number_of_heads,
+            input_channels=hidden_size + number_of_heads,
+            fpn_channels=intermediate_channel_sizes[::-1][-3:],
+            hidden_size=hidden_size,
+            activation_function=config.activation_function,
        )

-        self.bbox_attention = DetrMHAttentionMap(
-            hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
-        )
+        self.bbox_attention = DetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0)
        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
+    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor,
-        pixel_mask:
-        decoder_attention_mask:
-        encoder_outputs:
-        inputs_embeds:
-        decoder_inputs_embeds:
-        labels:
-
-
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.FloatTensor], DetrSegmentationOutput]:
+        pixel_mask: torch.LongTensor | None = None,
+        decoder_attention_mask: torch.FloatTensor | None = None,
+        encoder_outputs: torch.FloatTensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        decoder_inputs_embeds: torch.FloatTensor | None = None,
+        labels: list[dict] | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.FloatTensor] | DetrSegmentationOutput:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
-
+            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for queries that are **not masked**,
+            - 0 for queries that are **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-
-
+            Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires
+            multi-scale features from the backbone that are not available when bypassing it with inputs_embeds.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
+            embedded representation. Useful for tasks that require custom query initialization.
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
            dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
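The mask head is now constructed with explicit keyword arguments, and its FPN branch widths come from `intermediate_channel_sizes[::-1][-3:]`: the backbone's per-stage channel list is reversed into coarse-to-fine order and the deepest stage is dropped (its features reach the mask head through the input projection instead). A quick check with ResNet-50-style channel counts, which are an assumption for illustration:

```python
# Per-stage channel sizes of a ResNet-50-style backbone (illustrative values;
# the real list comes from the backbone configuration).
intermediate_channel_sizes = [256, 512, 1024, 2048]

# Reverse to coarse-to-fine order, then keep the last three entries, i.e. drop
# the deepest stage and keep the three intermediate stages for the FPN branches.
fpn_channels = intermediate_channel_sizes[::-1][-3:]
print(fpn_channels)  # [1024, 512, 256]
```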
@@ -1434,7 +1517,8 @@ class DetrForSegmentation(DetrPreTrainedModel):

        ```python
        >>> import io
-        >>> import
+        >>> import httpx
+        >>> from io import BytesIO
        >>> from PIL import Image
        >>> import torch
        >>> import numpy

@@ -1443,7 +1527,8 @@ class DetrForSegmentation(DetrPreTrainedModel):
        >>> from transformers.image_transforms import rgb_to_id

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>>
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
        >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

@@ -1468,83 +1553,77 @@ class DetrForSegmentation(DetrPreTrainedModel):
        5
        ```"""

-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
        batch_size, num_channels, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones((batch_size, height, width), device=device)

-
-
+        vision_features = self.detr.model.backbone(pixel_values, pixel_mask)
+        feature_map, mask = vision_features[-1]

-        #
-        feature_map, mask = features[-1]
-        batch_size, num_channels, height, width = feature_map.shape
+        # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
        projected_feature_map = self.detr.model.input_projection(feature_map)
-
-        # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
-        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
-
-
+        spatial_position_embeddings = self.detr.model.position_embedding(
+            shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
+        )
        flattened_mask = mask.flatten(1)

-        # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
-        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
-        # flattened_mask is a Tensor of shape (batch_size, height*width)
        if encoder_outputs is None:
            encoder_outputs = self.detr.model.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
-
-
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+                spatial_position_embeddings=spatial_position_embeddings,
+                **kwargs,
            )

-
-        query_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
+        object_queries_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
            batch_size, 1, 1
        )
-        queries = torch.zeros_like(query_position_embeddings)

-        #
+        # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
+        if decoder_inputs_embeds is not None:
+            queries = decoder_inputs_embeds
+        else:
+            queries = torch.zeros_like(object_queries_position_embeddings)
+
        decoder_outputs = self.detr.model.decoder(
            inputs_embeds=queries,
-            attention_mask=
-
-
-            encoder_hidden_states=encoder_outputs
+            attention_mask=decoder_attention_mask,
+            spatial_position_embeddings=spatial_position_embeddings,
+            object_queries_position_embeddings=object_queries_position_embeddings,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=flattened_mask,
-
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
        )

        sequence_output = decoder_outputs[0]

-        # Sixth, compute logits, pred_boxes and pred_masks
        logits = self.detr.class_labels_classifier(sequence_output)
        pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()

-
-
+        height, width = feature_map.shape[-2:]
+        memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view(
+            batch_size, self.config.d_model, height, width
+        )
+        attention_mask = flattened_mask.view(batch_size, height, width)

-
-
-
-
+        if attention_mask is not None:
+            min_dtype = torch.finfo(memory.dtype).min
+            attention_mask = torch.where(
+                attention_mask.unsqueeze(1).unsqueeze(1),
+                torch.tensor(0.0, device=memory.device, dtype=memory.dtype),
+                min_dtype,
+            )

-
+        bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=attention_mask)
+
+        seg_masks = self.mask_head(
+            features=projected_feature_map,
+            attention_masks=bbox_mask,
+            fpn_features=[vision_features[2][0], vision_features[1][0], vision_features[0][0]],
+        )

        pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

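Rather than masking inside the attention module, the flattened validity mask is now converted up front into an additive float mask: valid positions become `0.0` and padded positions become the most negative value representable in the activation dtype, so they vanish after the softmax. A small self-contained sketch of that conversion (toy shapes, `float16` chosen arbitrarily):

```python
import torch

# Boolean validity mask: True = real pixels, False = padding. Toy shape (batch, h*w).
flattened_mask = torch.tensor([[True, True, False, False]])
dtype = torch.float16
min_dtype = torch.finfo(dtype).min

additive_mask = torch.where(
    flattened_mask.unsqueeze(1).unsqueeze(1),   # broadcastable to (batch, 1, 1, h*w)
    torch.tensor(0.0, dtype=dtype),
    torch.tensor(min_dtype, dtype=dtype),
)
print(additive_mask)  # 0.0 where valid, -65504.0 (float16 min) where padded
```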
@@ -1552,20 +1631,13 @@ class DetrForSegmentation(DetrPreTrainedModel):
        if labels is not None:
            outputs_class, outputs_coord = None, None
            if self.config.auxiliary_loss:
-                intermediate = decoder_outputs.intermediate_hidden_states
+                intermediate = decoder_outputs.intermediate_hidden_states
                outputs_class = self.detr.class_labels_classifier(intermediate)
                outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
            )

-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
-            else:
-                output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
-            return ((loss, loss_dict) + output) if loss is not None else output
-
        return DetrSegmentationOutput(
            loss=loss,
            loss_dict=loss_dict,
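The hunk below removes the standalone `DetrMaskHeadSmallConv` and `DetrMHAttentionMap` implementations that the refactored `mask_head` and `bbox_attention` modules replace. For reference, a toy reproduction of the per-head attention-map computation the removed `DetrMHAttentionMap.forward` performed, with random tensors standing in for the learned query/key projections:

```python
import torch

batch, num_queries, num_heads, head_dim, h, w = 1, 3, 8, 32, 4, 4

# Queries split per head and a projected feature map split per head.
queries_per_head = torch.randn(batch, num_queries, num_heads, head_dim)
keys_per_head = torch.randn(batch, num_heads, head_dim, h, w)

# Scaled dot product between every query and every spatial position, per head,
# then a softmax over the flattened head/spatial axes, as in the removed module.
normalize_fact = float(head_dim) ** -0.5
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * normalize_fact, keys_per_head)
weights = torch.nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
print(weights.shape)  # torch.Size([1, 3, 8, 4, 4])
```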
@@ -1583,119 +1655,6 @@ class DetrForSegmentation(DetrPreTrainedModel):
        )


-def _expand(tensor, length: int):
-    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
-
-
-# taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py
-class DetrMaskHeadSmallConv(nn.Module):
-    """
-    Simple convolutional head, using group norm. Upsampling is done using a FPN approach
-    """
-
-    def __init__(self, dim, fpn_dims, context_dim):
-        super().__init__()
-
-        if dim % 8 != 0:
-            raise ValueError(
-                "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
-                " GroupNorm is set to 8"
-            )
-
-        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
-
-        self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
-        self.gn1 = nn.GroupNorm(8, dim)
-        self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
-        self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
-        self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
-        self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
-        self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
-        self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
-        self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
-        self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
-        self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
-
-        self.dim = dim
-
-        self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
-        self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
-        self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
-
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                init.kaiming_uniform_(m.weight, a=1)
-                init.constant_(m.bias, 0)
-
-    def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]):
-        # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
-        # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
-        # We expand the projected feature map to match the number of heads.
-        x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
-
-        x = self.lay1(x)
-        x = self.gn1(x)
-        x = nn.functional.relu(x)
-        x = self.lay2(x)
-        x = self.gn2(x)
-        x = nn.functional.relu(x)
-
-        cur_fpn = self.adapter1(fpns[0])
-        if cur_fpn.size(0) != x.size(0):
-            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
-        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
-        x = self.lay3(x)
-        x = self.gn3(x)
-        x = nn.functional.relu(x)
-
-        cur_fpn = self.adapter2(fpns[1])
-        if cur_fpn.size(0) != x.size(0):
-            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
-        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
-        x = self.lay4(x)
-        x = self.gn4(x)
-        x = nn.functional.relu(x)
-
-        cur_fpn = self.adapter3(fpns[2])
-        if cur_fpn.size(0) != x.size(0):
-            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
-        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
-        x = self.lay5(x)
-        x = self.gn5(x)
-        x = nn.functional.relu(x)
-
-        x = self.out_lay(x)
-        return x
-
-
-class DetrMHAttentionMap(nn.Module):
-    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
-
-    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
-        super().__init__()
-        self.num_heads = num_heads
-        self.hidden_dim = hidden_dim
-        self.dropout = nn.Dropout(dropout)
-
-        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
-        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
-
-        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
-
-    def forward(self, q, k, mask: Optional[Tensor] = None):
-        q = self.q_linear(q)
-        k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
-        queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
-        keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
-        weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
-
-        if mask is not None:
-            weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
-        weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
-        weights = self.dropout(weights)
-        return weights
-
-
__all__ = [
    "DetrForObjectDetection",
    "DetrForSegmentation",