transformers-5.0.0rc3-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1376 @@
|
|
|
1
|
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
2
|
+
# This file was automatically generated from src/transformers/models/eomt_dinov3/modular_eomt_dinov3.py.
|
|
3
|
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
4
|
+
# the file from the modular. If any change should be done, please apply the change to the
|
|
5
|
+
# modular_eomt_dinov3.py file directly. One of our CI enforces this.
|
|
6
|
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
7
|
+
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
|
8
|
+
#
|
|
9
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
10
|
+
# you may not use this file except in compliance with the License.
|
|
11
|
+
# You may obtain a copy of the License at
|
|
12
|
+
#
|
|
13
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
14
|
+
#
|
|
15
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
16
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
17
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
18
|
+
# See the License for the specific language governing permissions and
|
|
19
|
+
# limitations under the License.
|
|
20
|
+
|
|
21
|
+
import math
|
|
22
|
+
from collections.abc import Callable
|
|
23
|
+
from dataclasses import dataclass
|
|
24
|
+
from typing import Optional
|
|
25
|
+
|
|
26
|
+
import numpy as np
|
|
27
|
+
import torch
|
|
28
|
+
import torch.nn.functional as F
|
|
29
|
+
from torch import Tensor, nn
|
|
30
|
+
|
|
31
|
+
from ... import initialization as init
|
|
32
|
+
from ...activations import ACT2FN
|
|
33
|
+
from ...file_utils import ModelOutput, is_scipy_available, requires_backends
|
|
34
|
+
from ...modeling_layers import GradientCheckpointingLayer
|
|
35
|
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
36
|
+
from ...processing_utils import Unpack
|
|
37
|
+
from ...pytorch_utils import compile_compatible_method_lru_cache
|
|
38
|
+
from ...utils import TransformersKwargs, auto_docstring, is_accelerate_available
|
|
39
|
+
from ...utils.generic import check_model_inputs, maybe_autocast
|
|
40
|
+
from .configuration_eomt_dinov3 import EomtDinov3Config
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
if is_scipy_available():
|
|
44
|
+
from scipy.optimize import linear_sum_assignment
|
|
45
|
+
|
|
46
|
+
if is_accelerate_available():
|
|
47
|
+
from accelerate import PartialState
|
|
48
|
+
from accelerate.utils import reduce
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def rotate_half(x):
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    The last dimension is split at ``size // 2`` into a leading part ``a`` and a
    trailing part ``b``; the result is ``-b`` followed by ``a``.
    """
    half = x.shape[-1] // 2
    leading, trailing = x[..., :half], x[..., half:]
    return torch.cat((-trailing, leading), dim=-1)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (eager) scaled dot-product attention.

    Args:
        module: The calling module; only its ``training`` flag is read, to gate dropout.
        query, key, value: Attention inputs. ``key.transpose(2, 3)`` implies 4-D tensors,
            presumably ``(batch, heads, seq, head_dim)`` — TODO confirm against callers.
        attention_mask: Optional additive mask; it is sliced to the key length along its
            last dim before being added to the scores.
        scaling: Multiplier for the raw scores; defaults to ``query.size(-1) ** -0.5``.
        dropout: Dropout probability applied to the attention weights (only when training).

    Returns:
        ``(attn_output, attn_weights)``; ``attn_output`` is ``attn_weights @ value`` with
        dims 1 and 2 swapped and made contiguous.
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    # Raw attention scores, optionally biased by the additive mask.
    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : key.shape[-2]]

    weights = nn.functional.softmax(scores, dim=-1)
    weights = nn.functional.dropout(weights, p=dropout, training=module.training)

    output = torch.matmul(weights, value).transpose(1, 2).contiguous()
    return output, weights
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to the patch tokens of ``q`` and ``k``.

    The token axis is ``dim=-2``. Its trailing ``sin.shape[-2]`` entries are treated as
    patch tokens and rotated; the leading entries (cls token and register tokens) pass
    through unchanged.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table for the patch tokens.
        sin (`torch.Tensor`): Sine table for the patch tokens.

    Returns:
        `tuple(torch.Tensor)`: the query and key tensors with RoPE applied to their patch tokens.
    """
    num_patches = sin.shape[-2]
    # Everything before the patches: cls token + register tokens.
    num_prefix = q.shape[-2] - num_patches

    def _rotate_patches(tensor):
        prefix, patches = tensor.split((num_prefix, num_patches), dim=-2)
        half = patches.shape[-1] // 2
        # Same as rotate_half(patches): [a, b] -> [-b, a] on the last dimension.
        swapped = torch.cat((-patches[..., half:], patches[..., :half]), dim=-1)
        return torch.cat((prefix, patches * cos + swapped * sin), dim=-2)

    return _rotate_patches(q), _rotate_patches(k)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class EomtDinov3Attention(nn.Module):
    """
    Multi-headed self-attention compatible with ALL_ATTENTION_FUNCTIONS.

    Rotary position embeddings are applied to the patch tokens of queries and keys before
    attention. The attention kernel is looked up in ``ALL_ATTENTION_FUNCTIONS`` by
    ``config._attn_implementation``, falling back to ``eager_attention_forward``.
    """

    def __init__(self, config: EomtDinov3Config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.scaling = self.head_dim**-0.5
        # No causal masking: this is bidirectional attention over image tokens.
        self.is_causal = False

        self.dropout = config.attention_dropout
        # Projections are created in this order (k, v, q, o); keep it so seeded init is reproducible.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias)

        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias)
        self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Attend over ``hidden_states`` of shape ``(batch, tokens, hidden_size)``.

        Args:
            hidden_states: Input tokens, unpacked as ``(batch_size, patches, _)``.
            attention_mask: Optional mask forwarded unchanged to the attention kernel.
            position_embeddings: ``(cos, sin)`` RoPE tables. They are unpacked unconditionally,
                so callers must provide them despite the ``None`` default.
            **kwargs: Forwarded to the attention kernel.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has shape ``(batch, tokens, hidden_size)``;
            ``attn_weights`` is whatever the selected kernel returns.
        """
        batch_size, patches, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, tokens, hidden) -> (batch, heads, tokens, head_dim)
        query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back: (batch, tokens, heads, head_dim) -> (batch, tokens, hidden)
        attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class EomtDinov3Embeddings(nn.Module):
    """
    Turn pixel values into the backbone's input token sequence.

    The output concatenates, along dim 1:
    - one CLS token;
    - ``config.num_register_tokens`` register tokens;
    - one embedding per ``patch_size x patch_size`` patch, produced by a strided convolution.

    Patches flagged in ``bool_masked_pos`` are replaced by a learned mask token.
    """

    def __init__(self, config: EomtDinov3Config):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        # Creation order is fixed: `torch.randn` and the conv's default init both draw from the global RNG.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden))
        self.register_tokens = nn.Parameter(torch.empty(1, config.num_register_tokens, hidden))
        self.patch_embeddings = nn.Conv2d(
            config.num_channels, hidden, kernel_size=config.patch_size, stride=config.patch_size
        )
        # CLS token plus register tokens precede the patch tokens.
        self.num_prefix_tokens = 1 + config.num_register_tokens

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor:
        # Run the conv in its own parameter dtype, then
        # (batch, hidden, grid_h, grid_w) -> (batch, grid_h * grid_w, hidden).
        conv_dtype = self.patch_embeddings.weight.dtype
        patches = self.patch_embeddings(pixel_values.to(dtype=conv_dtype)).flatten(2).transpose(1, 2)

        if bool_masked_pos is not None:
            replacement = self.mask_token.to(patches.dtype)
            patches = torch.where(bool_masked_pos.unsqueeze(-1), replacement, patches)

        batch = pixel_values.shape[0]
        prefix = [self.cls_token.expand(batch, -1, -1), self.register_tokens.expand(batch, -1, -1)]
        return torch.cat(prefix + [patches], dim=1)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Stochastic depth: randomly zero whole samples of ``input`` along its first dimension.

    Surviving samples are rescaled by ``1 / (1 - drop_prob)``. When not training, or when
    ``drop_prob`` is 0, ``input`` is returned as-is (the same object).
    """
    if not training or drop_prob == 0.0:
        return input

    keep_prob = 1 - drop_prob
    # One random draw per sample, broadcast over all remaining dims (works for any rank).
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    keep_mask = torch.rand(mask_shape, dtype=input.dtype, device=input.device).add_(keep_prob).floor_()
    return input.div(keep_prob) * keep_mask
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class EomtDinov3DropPath(nn.Module):
|
|
238
|
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
|
239
|
+
|
|
240
|
+
def __init__(self, drop_prob: float | None = None) -> None:
|
|
241
|
+
super().__init__()
|
|
242
|
+
self.drop_prob = drop_prob
|
|
243
|
+
|
|
244
|
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
245
|
+
return drop_path(hidden_states, self.drop_prob, self.training)
|
|
246
|
+
|
|
247
|
+
def extra_repr(self) -> str:
|
|
248
|
+
return f"p={self.drop_prob}"
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class EomtDinov3MLP(nn.Module):
    """Two-layer feed-forward block: ``down_proj(act_fn(up_proj(x)))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        use_bias = config.mlp_bias
        # Keep up_proj before down_proj so seeded initialization is unchanged.
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        expanded = self.up_proj(x)
        activated = self.act_fn(expanded)
        return self.down_proj(activated)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class EomtDinov3GatedMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        use_bias = config.mlp_bias
        # Creation order gate -> up -> down matters for seeded initialization.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
class EomtDinov3Layer(GradientCheckpointingLayer):
    """Pre-norm transformer block (the ``Block`` class of the original implementation).

    Two residual branches are applied in sequence:
    ``x = x + drop_path(layer_scale1(attention(norm1(x))))`` and then
    ``x = x + drop_path(layer_scale2(mlp(norm2(x))))``.
    """

    def __init__(self, config: EomtDinov3Config):
        super().__init__()
        hidden, eps = config.hidden_size, config.layer_norm_eps

        # Sub-module creation/registration order is significant (seeded init, state_dict order).
        self.norm1 = nn.LayerNorm(hidden, eps=eps)
        self.attention = EomtDinov3Attention(config)
        self.layer_scale1 = EomtDinov3LayerScale(config)
        if config.drop_path_rate > 0.0:
            self.drop_path = EomtDinov3DropPath(config.drop_path_rate)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = nn.LayerNorm(hidden, eps=eps)
        mlp_cls = EomtDinov3GatedMLP if config.use_gated_mlp else EomtDinov3MLP
        self.mlp = mlp_cls(config)
        self.layer_scale2 = EomtDinov3LayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        # Attention branch (attention weights are discarded).
        attn_out, _ = self.attention(
            self.norm1(hidden_states),
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        hidden_states = hidden_states + self.drop_path(self.layer_scale1(attn_out))

        # Feed-forward branch.
        mlp_out = self.layer_scale2(self.mlp(self.norm2(hidden_states)))
        return hidden_states + self.drop_path(mlp_out)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
class EomtDinov3LayerScale(nn.Module):
    """Learnable per-channel scale over the last dimension, initialized to ``config.layerscale_value``."""

    def __init__(self, config) -> None:
        super().__init__()
        initial = torch.ones(config.hidden_size) * config.layerscale_value
        self.lambda1 = nn.Parameter(initial)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return torch.mul(hidden_state, self.lambda1)
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
@compile_compatible_method_lru_cache(maxsize=32)
def get_patches_center_coordinates(
    num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """
    Return the (y, x) centers of a ``num_patches_h x num_patches_w`` patch grid, normalized to [-1, +1].

    Args:
        num_patches_h (int): Number of patches along the vertical (height) axis.
        num_patches_w (int): Number of patches along the horizontal (width) axis.
        dtype (torch.dtype): Data type of the returned tensor.
        device (torch.device): Device on which the returned tensor is created.

    Returns:
        torch.Tensor: Shape ``(num_patches_h * num_patches_w, 2)``, rows in row-major (``ij``) order,
        each holding the normalized ``(y, x)`` center of one patch.
    """

    def _axis_centers(count: int) -> torch.Tensor:
        # Centers of `count` equal cells on [0, 1]: (i + 0.5) / count.
        return torch.arange(0.5, count, dtype=dtype, device=device) / count

    grid_y, grid_x = torch.meshgrid(_axis_centers(num_patches_h), _axis_centers(num_patches_w), indexing="ij")
    centers = torch.stack((grid_y, grid_x), dim=-1).flatten(0, 1)
    # Map [0, 1] onto [-1, +1].
    return 2.0 * centers - 1.0
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def augment_patches_center_coordinates(
|
|
366
|
+
coords: torch.Tensor,
|
|
367
|
+
shift: float | None = None,
|
|
368
|
+
jitter: float | None = None,
|
|
369
|
+
rescale: float | None = None,
|
|
370
|
+
) -> torch.Tensor:
|
|
371
|
+
# Shift coords by adding a uniform value in [-shift, shift]
|
|
372
|
+
if shift is not None:
|
|
373
|
+
shift_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
|
|
374
|
+
shift_hw = shift_hw.uniform_(-shift, shift)
|
|
375
|
+
coords = coords + shift_hw
|
|
376
|
+
|
|
377
|
+
# Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
|
|
378
|
+
if jitter is not None:
|
|
379
|
+
jitter_range = np.log(jitter)
|
|
380
|
+
jitter_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
|
|
381
|
+
jitter_hw = jitter_hw.uniform_(-jitter_range, jitter_range).exp()
|
|
382
|
+
coords = coords * jitter_hw
|
|
383
|
+
|
|
384
|
+
# Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
|
|
385
|
+
if rescale is not None:
|
|
386
|
+
rescale_range = np.log(rescale)
|
|
387
|
+
rescale_hw = torch.empty(1, device=coords.device, dtype=coords.dtype)
|
|
388
|
+
rescale_hw = rescale_hw.uniform_(-rescale_range, rescale_range).exp()
|
|
389
|
+
coords = coords * rescale_hw
|
|
390
|
+
|
|
391
|
+
return coords
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
class EomtDinov3RotaryEmbedding(nn.Module):
    """2D rotary position embedding computed from patch-center coordinates.

    ``forward`` returns ``(cos, sin)`` tables with one row per patch; the number of patches is
    derived from the input's height and width and ``config.patch_size``. In training mode, the
    coordinates are augmented with the config's ``pos_embed_shift`` / ``pos_embed_jitter`` /
    ``pos_embed_rescale`` values.
    """

    inv_freq: Tensor

    def __init__(self, config: EomtDinov3Config, device=None):
        super().__init__()
        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            raise ValueError("`EomtDinov3` only supports `default` RoPE! Please check your `rope_type`")
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers: recomputed from the config, not saved in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute ``(cos, sin)`` for the patch grid of ``pixel_values``, cast to ``pixel_values.dtype``."""
        _, _, height, width = pixel_values.shape
        num_patches_h = height // self.config.patch_size
        num_patches_w = width // self.config.patch_size

        device = pixel_values.device
        device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            # Although we could precompute static patch_coords from image_size and patch_size in the config,
            # the model was trained with random_scale, so it can process images of varying sizes.
            # Therefore, it's better to compute patch_coords dynamically (with lru_cache).
            patch_coords = get_patches_center_coordinates(
                num_patches_h, num_patches_w, dtype=torch.float32, device=device
            )
            if self.training:
                patch_coords = augment_patches_center_coordinates(
                    patch_coords,
                    shift=self.config.pos_embed_shift,
                    jitter=self.config.pos_embed_jitter,
                    rescale=self.config.pos_embed_rescale,
                )

            # (height * width, 2, head_dim / 4) -> (height * width, head_dim / 2) -> (height * width, head_dim)
            angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
            angles = angles.flatten(1, 2)
            angles = angles.tile(2)

            cos = torch.cos(angles)
            sin = torch.sin(angles)

        dtype = pixel_values.dtype
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    @staticmethod
    def compute_default_rope_parameters(
        config: EomtDinov3Config | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple[torch.Tensor, float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        head_dim = config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # head_dim / 4 frequencies: each of the two axes gets head_dim / 2 channels (sin and cos halves).
        inv_freq = 1 / base ** torch.arange(0, 1, 4 / head_dim, dtype=torch.float32, device=device)
        return inv_freq, attention_factor
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
|
|
474
|
+
def sample_point(
    input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs
) -> torch.Tensor:
    """
    Sample ``input_features`` at ``[0, 1]``-normalized points via ``torch.nn.functional.grid_sample``.

    Unlike ``grid_sample``, this also accepts 3-D point tensors.

    Args:
        input_features (`torch.Tensor` of shape (batch_size, channels, height, width)):
            Feature map to sample from.
        point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or
            (batch_size, grid_height, grid_width, 2)):
            Point coordinates normalized to ``[0, 1] x [0, 1]``.
        add_dim (`bool`):
            Whether to squeeze dim 3 of the result. Forced to True for 3-D ``point_coordinates``.
        **kwargs: Forwarded to ``grid_sample`` (e.g. ``align_corners``).

    Returns:
        point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or
        (batch_size, channels, grid_height, grid_width)):
            Sampled features for the points in ``point_coordinates``.
    """
    if point_coordinates.dim() == 3:
        # Treat the points as a (num_points x 1) grid; the extra axis is squeezed again below.
        add_dim = True
        point_coordinates = point_coordinates.unsqueeze(2)

    # grid_sample expects coordinates in [-1, 1]; samples bilinearly by default.
    grid = 2.0 * point_coordinates - 1.0
    sampled = torch.nn.functional.grid_sample(input_features, grid, **kwargs)
    return sampled.squeeze(3) if add_dim else sampled
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
|
|
507
|
+
"""
|
|
508
|
+
A pair wise version of the dice loss, see `dice_loss` for usage.
|
|
509
|
+
|
|
510
|
+
Args:
|
|
511
|
+
inputs (`torch.Tensor`):
|
|
512
|
+
A tensor representing a mask
|
|
513
|
+
labels (`torch.Tensor`):
|
|
514
|
+
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
|
|
515
|
+
(0 for the negative class and 1 for the positive class).
|
|
516
|
+
|
|
517
|
+
Returns:
|
|
518
|
+
`torch.Tensor`: The computed loss between each pairs.
|
|
519
|
+
"""
|
|
520
|
+
inputs = inputs.sigmoid().flatten(1)
|
|
521
|
+
numerator = 2 * torch.matmul(inputs, labels.T)
|
|
522
|
+
# using broadcasting to get a [num_queries, NUM_CLASSES] matrix
|
|
523
|
+
denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
|
|
524
|
+
loss = 1 - (numerator + 1) / (denominator + 1)
|
|
525
|
+
return loss
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    r"""
    Mean sigmoid binary cross-entropy for every (prediction, label) pair; pairwise counterpart of
    `sigmoid_cross_entropy_loss`.

    Args:
        inputs (`torch.Tensor`):
            Mask logits, one row per prediction.
        labels (`torch.Tensor`):
            Binary targets (0 = negative, 1 = positive), one row per label, with the same
            number of columns as ``inputs``.

    Returns:
        loss (`torch.Tensor`): Matrix of shape ``(num_predictions, num_labels)``.
    """
    num_points = inputs.shape[1]
    bce = nn.BCEWithLogitsLoss(reduction="none")

    # Per-element cost if the target were 1 / 0, pre-averaged over the points.
    cost_if_positive = bce(inputs, torch.ones_like(inputs)) / num_points
    cost_if_negative = bce(inputs, torch.zeros_like(inputs)) / num_points

    # Pick the matching cost for each label element via matrix products.
    return cost_if_positive @ labels.T + cost_if_negative @ (1 - labels).T
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
# Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py
|
|
556
|
+
class EomtDinov3HungarianMatcher(nn.Module):
    """This class computes an assignment between the labels and the predictions of the network.

    For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
    predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).
    """

    def __init__(
        self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544
    ):
        """Creates the matcher

        Params:
            cost_class (`float`, *optional*, defaults to 1.0):
                Relative weight of the classification error in the matching cost.
            cost_mask (`float`, *optional*, defaults to 1.0):
                This is the relative weight of the sigmoid cross-entropy loss of the binary mask in the matching cost.
            cost_dice (`float`, *optional*, defaults to 1.0):
                This is the relative weight of the dice loss of the binary mask in the matching cost.
            num_points (`int`, *optional*, defaults to 12544):
                No. of points to sample on which the mask loss will be calculated. The same set of K points are
                uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite
                matching.

        Raises:
            ValueError: if all three cost weights are 0.
        """
        super().__init__()
        if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
            raise ValueError("All costs can't be 0")

        self.num_points = num_points
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice

    @torch.no_grad()
    def forward(
        self,
        masks_queries_logits: torch.Tensor,
        class_queries_logits: torch.Tensor,
        mask_labels: torch.Tensor,
        class_labels: torch.Tensor,
    ) -> list[tuple[Tensor]]:
        """
        Params:
            masks_queries_logits (`torch.Tensor`):
                A tensor of dim `batch_size, num_queries, height, width` with the predicted masks.
            class_queries_logits (`torch.Tensor`):
                A tensor of dim `batch_size, num_queries, num_labels` with the classification logits.
            mask_labels (`torch.Tensor`):
                Per batch element, a tensor of dim `num_target_boxes, height, width` containing the target masks.
            class_labels (`torch.Tensor`):
                Per batch element, a tensor of dim `num_target_boxes` (where num_target_boxes is the number of
                ground-truth objects in the target) containing the class labels.

        Returns:
            matched_indices (`list[tuple[Tensor]]`): A list of size batch_size, containing tuples of (index_i, index_j)
            where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected labels (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
        """
        indices: list[tuple[np.ndarray, np.ndarray]] = []

        # iterate through batch size
        batch_size = masks_queries_logits.shape[0]
        for i in range(batch_size):
            pred_probs = class_queries_logits[i].softmax(-1)
            pred_mask = masks_queries_logits[i]

            # Classification cost: -proba[target class] approximates the NLL. The constant 1 of
            # "1 - proba" does not change the matching and is omitted.
            cost_class = -pred_probs[:, class_labels[i]]
            target_mask = mask_labels[i].to(pred_mask)
            target_mask = target_mask[:, None]
            pred_mask = pred_mask[:, None]

            # Sample ground truth and predicted masks at the same random points.
            point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device)

            target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1)
            target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1)

            pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1)
            pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1)

            # compute the cross entropy loss between each mask pairs -> shape (num_queries, num_labels)
            cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
            # Compute the dice loss between each mask pairs -> shape (num_queries, num_labels)
            cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
            # final cost matrix
            cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
            # Eliminate infinite/NaN values to avoid scipy's ``ValueError: cost matrix is infeasible``.
            # `clamp` with Python scalars works on any device without allocating helper tensors.
            cost_matrix = cost_matrix.clamp(min=-1e10, max=1e10)
            cost_matrix = torch.nan_to_num(cost_matrix, 0)
            # Hungarian assignment in scipy. Cast to float32 first: numpy cannot convert bfloat16 tensors.
            assigned_indices: tuple[np.ndarray, np.ndarray] = linear_sum_assignment(
                cost_matrix.to(torch.float32).cpu()
            )
            indices.append(assigned_indices)

        # It could be stacked in one tensor
        matched_indices = [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
        ]
        return matched_indices
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
|
|
663
|
+
r"""
|
|
664
|
+
Compute the DICE loss, similar to generalized IOU for masks as follows:
|
|
665
|
+
|
|
666
|
+
$$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$
|
|
667
|
+
|
|
668
|
+
In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow
|
|
669
|
+
|
|
670
|
+
$$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$
|
|
671
|
+
|
|
672
|
+
Args:
|
|
673
|
+
inputs (`torch.Tensor`):
|
|
674
|
+
A tensor representing a mask.
|
|
675
|
+
labels (`torch.Tensor`):
|
|
676
|
+
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
|
|
677
|
+
(0 for the negative class and 1 for the positive class).
|
|
678
|
+
num_masks (`int`):
|
|
679
|
+
The number of masks present in the current batch, used for normalization.
|
|
680
|
+
|
|
681
|
+
Returns:
|
|
682
|
+
`torch.Tensor`: The computed loss.
|
|
683
|
+
"""
|
|
684
|
+
probs = inputs.sigmoid().flatten(1)
|
|
685
|
+
numerator = 2 * (probs * labels).sum(-1)
|
|
686
|
+
denominator = probs.sum(-1) + labels.sum(-1)
|
|
687
|
+
loss = 1 - (numerator + 1) / (denominator + 1)
|
|
688
|
+
loss = loss.sum() / num_masks
|
|
689
|
+
return loss
|
|
690
|
+
|
|
691
|
+
|
|
692
|
+
def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
    r"""
    Binary cross entropy on logits, averaged over dimension 1 and normalized by the number of masks.

    Args:
        inputs (`torch.Tensor`):
            A float tensor of logits with at least two dimensions; the first dimension indexes the masks.
        labels (`torch.Tensor`):
            A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks (`int`):
            The number of masks present in the current batch, used for normalization.

    Returns:
        loss (`torch.Tensor`): The computed loss.
    """
    # Element-wise BCE (no reduction) so the averaging/normalization below can be controlled explicitly.
    cross_entropy_loss = F.binary_cross_entropy_with_logits(inputs, labels, reduction="none")

    loss = cross_entropy_loss.mean(1).sum() / num_masks
    return loss
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
# Adapted from https://github.com/facebookresearch/EomtDinov3/blob/main/eomt_dinov3/modeling/criterion.py
|
|
712
|
+
class EomtDinov3Loss(nn.Module):
    def __init__(self, config: EomtDinov3Config, weight_dict: dict[str, float]):
        """
        The EomtDinov3 Loss. The loss is computed very similarly to DETR, in two steps: 1) a Hungarian assignment is
        computed between the ground-truth masks and the outputs of the model; 2) each matched ground-truth /
        prediction pair is supervised (class and mask).

        Args:
            config (`EomtDinov3Config`):
                The configuration for EomtDinov3 model also containing loss calculation specific parameters.
            weight_dict (`dict[str, float]`):
                A dictionary of weights for the different losses. It is stored on the module; this class itself
                does not apply it to the returned losses.
        """
        super().__init__()
        requires_backends(self, ["scipy"])
        self.num_labels = config.num_labels
        self.weight_dict = weight_dict

        # Weight to apply to the null class
        self.eos_coef = config.no_object_weight
        empty_weight = torch.ones(self.num_labels + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

        # pointwise mask loss parameters
        self.num_points = config.train_num_points
        self.oversample_ratio = config.oversample_ratio
        self.importance_sample_ratio = config.importance_sample_ratio

        self.matcher = EomtDinov3HungarianMatcher(
            cost_class=config.class_weight,
            cost_dice=config.dice_weight,
            cost_mask=config.mask_weight,
            num_points=self.num_points,
        )

    def _max_by_axis(self, sizes: list[list[int]]) -> list[int]:
        """Return the element-wise maximum over a list of equally long size lists.

        The first entry is copied so that the caller's list is never modified in place.
        """
        maxes = list(sizes[0])
        for sublist in sizes[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    # Adapted from nested_tensor_from_tensor_list() in original implementation
    def _pad_images_to_max_in_batch(self, tensors: list[Tensor]) -> tuple[Tensor, Tensor]:
        """Zero-pad a list of 3D tensors to their element-wise maximum shape and stack them.

        Returns:
            A tuple `(padded_tensors, padding_masks)`: the stacked, zero-padded tensors of shape
            `(batch_size, max_dim0, max_height, max_width)` and a boolean mask of shape
            `(batch_size, max_height, max_width)` that is `True` on padded pixels and `False` on real ones.
        """
        # get the maximum size in the batch
        max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
        # compute final size
        batch_shape = [len(tensors)] + max_size
        batch_size, _, height, width = batch_shape
        dtype = tensors[0].dtype
        device = tensors[0].device
        padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
        padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
        # pad the tensors to the size of the biggest one
        for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
            padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
            padding_mask[: tensor.shape[1], : tensor.shape[2]] = False

        return padded_tensors, padding_masks

    def loss_labels(
        self, class_queries_logits: Tensor, class_labels: list[Tensor], indices: list[tuple[Tensor, Tensor]]
    ) -> dict[str, Tensor]:
        """Compute the losses related to the labels using cross entropy.

        Queries that are not matched to any target are supervised with the null class (index `num_labels`), whose
        contribution is down-weighted by `eos_coef` through `empty_weight`.

        Args:
            class_queries_logits (`torch.Tensor`):
                A tensor of shape `(batch_size, num_queries, num_labels + 1)`.
            class_labels (`list[torch.Tensor]`):
                List of class labels of shape `(labels)`.
            indices (`list[tuple[torch.Tensor, torch.Tensor]]`):
                The indices computed by the Hungarian matcher: one `(prediction_indices, target_indices)` pair per
                image.

        Returns:
            `dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
            - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
        """
        pred_logits = class_queries_logits
        batch_size, num_queries, _ = pred_logits.shape
        criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
        idx = self._get_predictions_permutation_indices(indices)  # (batch_indices, query_indices) of matched queries
        # class of the matched target for every matched query, shape (num_matched,)
        target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
        # every query defaults to the null class; matched queries receive their target class
        target_classes = torch.full(
            (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
        )
        target_classes[idx] = target_classes_o
        # CrossEntropyLoss expects the class dimension second:
        # (batch_size, num_queries, num_labels + 1) -> (batch_size, num_labels + 1, num_queries)
        pred_logits_transposed = pred_logits.transpose(1, 2)
        loss_ce = criterion(pred_logits_transposed, target_classes)
        losses = {"loss_cross_entropy": loss_ce}
        return losses

    def loss_masks(
        self,
        masks_queries_logits: torch.Tensor,
        mask_labels: list[torch.Tensor],
        indices: list[tuple[Tensor, Tensor]],
        num_masks: int,
    ) -> dict[str, torch.Tensor]:
        """Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss.

        Both losses are evaluated on a set of sampled points rather than on the full masks.

        Args:
            masks_queries_logits (`torch.Tensor`):
                A tensor of shape `(batch_size, num_queries, height, width)`.
            mask_labels (`list[torch.Tensor]`):
                List of mask labels of shape `(labels, height, width)`.
            indices (`list[tuple[torch.Tensor, torch.Tensor]]`):
                The indices computed by the Hungarian matcher.
            num_masks (`int`):
                The number of masks, used for normalization.

        Returns:
            losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys:
            - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth
              masks.
            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
        """
        src_idx = self._get_predictions_permutation_indices(indices)
        tgt_idx = self._get_targets_permutation_indices(indices)
        # shape (num_matched, height, width)
        pred_masks = masks_queries_logits[src_idx]
        # pad all targets to the largest size in the batch and stack them: (batch_size, max_labels, height, width)
        target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
        # shape (num_matched, height, width)
        target_masks = target_masks[tgt_idx]

        # No need to upsample predictions as we are using normalized coordinates
        pred_masks = pred_masks[:, None]
        target_masks = target_masks[:, None]

        # Sample point coordinates (no gradient flows through the point selection or the labels)
        with torch.no_grad():
            point_coordinates = self.sample_points_using_uncertainty(
                pred_masks,
                lambda logits: self.calculate_uncertainty(logits),
                self.num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            )

            point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1)

        point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1)

        losses = {
            "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
            "loss_dice": dice_loss(point_logits, point_labels, num_masks),
        }

        del pred_masks
        del target_masks
        return losses

    def _get_predictions_permutation_indices(self, indices):
        """Return `(batch_indices, prediction_indices)` selecting the matched queries across the batch."""
        # Permute predictions following indices
        batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        predictions_indices = torch.cat([src for (src, _) in indices])
        return batch_indices, predictions_indices

    def _get_targets_permutation_indices(self, indices):
        """Return `(batch_indices, target_indices)` selecting the matched targets across the batch."""
        # Permute labels following indices
        batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        target_indices = torch.cat([tgt for (_, tgt) in indices])
        return batch_indices, target_indices

    def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Estimate uncertainty from the distance between each logit and 0.0 (the decision boundary): the score is
        `-|logits|`, so predictions closest to the boundary are the most uncertain.

        Args:
            logits (`torch.Tensor`):
                A tensor of shape `(R, 1, ...)` of class-agnostic mask logits, where R is the total number of
                predicted masks in all images.

        Returns:
            scores (`torch.Tensor`): A tensor of shape `(R, 1, ...)` that contains uncertainty scores with the most
            uncertain locations having the highest uncertainty score.
        """
        uncertainty_scores = -(torch.abs(logits))
        return uncertainty_scores

    def sample_points_using_uncertainty(
        self,
        logits: torch.Tensor,
        uncertainty_function,
        num_points: int,
        oversample_ratio: int,
        importance_sample_ratio: float,
    ) -> torch.Tensor:
        """
        This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The
        uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit
        prediction as input.

        `num_points * oversample_ratio` random candidates are drawn per mask; the `importance_sample_ratio * num_points`
        most uncertain candidates are kept and the remainder is filled with fresh uniformly random points.

        Args:
            logits (`torch.Tensor`):
                Mask logits; as called from `loss_masks`, a tensor of shape `(num_masks, 1, height, width)`.
            uncertainty_function:
                A function that takes logit predictions for P points and returns their uncertainties.
            num_points (`int`):
                The number of points P to sample.
            oversample_ratio (`int`):
                Oversampling parameter.
            importance_sample_ratio (`float`):
                Ratio of points that are sampled via importance sampling.

        Returns:
            point_coordinates (`torch.Tensor`):
                Coordinates for P sampled points, of shape `(num_masks, P, 2)`.
        """

        num_boxes = logits.shape[0]
        num_points_sampled = int(num_points * oversample_ratio)

        # Get random point coordinates
        point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
        # Get sampled prediction value for the point coordinates
        point_logits = sample_point(logits, point_coordinates, align_corners=False)
        # Calculate the uncertainties based on the sampled prediction values of the points
        point_uncertainties = uncertainty_function(point_logits)

        num_uncertain_points = int(importance_sample_ratio * num_points)
        num_random_points = num_points - num_uncertain_points

        idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
        # Offset the per-mask top-k indices so they index the flattened (num_boxes * num_points_sampled, 2) table.
        shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
        idx += shift[:, None]
        point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)

        if num_random_points > 0:
            point_coordinates = torch.cat(
                [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
                dim=1,
            )
        return point_coordinates

    def forward(
        self,
        masks_queries_logits: torch.Tensor,
        class_queries_logits: torch.Tensor,
        mask_labels: list[torch.Tensor],
        class_labels: list[torch.Tensor],
        auxiliary_predictions: list[dict[str, torch.Tensor]] | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        This performs the loss computation.

        Args:
            masks_queries_logits (`torch.Tensor`):
                A tensor of shape `(batch_size, num_queries, height, width)`.
            class_queries_logits (`torch.Tensor`):
                A tensor of shape `(batch_size, num_queries, num_labels + 1)`.
            mask_labels (`list[torch.Tensor]`):
                List of mask labels of shape `(labels, height, width)`.
            class_labels (`list[torch.Tensor]`):
                List of class labels of shape `(labels)`.
            auxiliary_predictions (`list[dict[str, torch.Tensor]]`, *optional*):
                Intermediate predictions, one dict per entry with keys `masks_queries_logits` and
                `class_queries_logits`. Each entry is supervised the same way as the main prediction.

        Returns:
            losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys:
            - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
            - **loss_mask** -- The loss computed using sigmoid cross_entropy loss on the predicted and ground truth
              masks.
            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
            When `auxiliary_predictions` is given, the dictionary additionally contains the same losses for each
            auxiliary prediction, suffixed with `_{index}`.
        """

        # retrieve the matching between the outputs of the last layer and the labels
        indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
        # compute the average number of target masks for normalization purposes
        num_masks = self.get_num_masks(class_labels, device=class_labels[0].device)
        # get all the losses
        losses: dict[str, Tensor] = {
            **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
            **self.loss_labels(class_queries_logits, class_labels, indices),
        }
        # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if auxiliary_predictions is not None:
            for idx, aux_outputs in enumerate(auxiliary_predictions):
                masks_queries_logits = aux_outputs["masks_queries_logits"]
                class_queries_logits = aux_outputs["class_queries_logits"]
                loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
                loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
                losses.update(loss_dict)

        return losses

    def get_num_masks(self, class_labels: list[torch.Tensor], device: torch.device) -> torch.Tensor:
        """
        Computes the average number of target masks per process across the batch, for normalization purposes.

        When an accelerate distributed state is initialized, the count is reduced across processes and divided by the
        number of processes. The result is clamped to at least 1 to avoid dividing by zero.
        """
        num_masks = sum(len(classes) for classes in class_labels)
        num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
        world_size = 1
        if is_accelerate_available():
            if PartialState._shared_state != {}:
                num_masks = reduce(num_masks)
                world_size = PartialState().num_processes

        num_masks = torch.clamp(num_masks / world_size, min=1)
        return num_masks
|
|
1022
|
+
|
|
1023
|
+
|
|
1024
|
+
@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of [`EomtDinov3ForUniversalSegmentation`].

    This output can be directly passed to [`~EomtDinov3ImageProcessor.post_process_semantic_segmentation`] or
    [`~EomtDinov3ImageProcessor.post_process_instance_segmentation`] or
    [`~EomtDinov3ImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
    [`~EomtDinov3ImageProcessor`] for details regarding usage.
    """
)
class EomtDinov3ForUniversalSegmentationOutput(ModelOutput):
    r"""
    loss (`torch.Tensor`, *optional*):
        The computed loss, returned when labels are present.
    class_queries_logits (`torch.FloatTensor`):
        A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
        query. Note the `+ 1` is needed because we incorporate the null class.
    masks_queries_logits (`torch.FloatTensor`):
        A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
        query.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Layer-normalized output of the last layer, i.e. the full token sequence (query tokens followed by the
        encoder tokens).
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of all layers of the model.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Self-attention weights of the layers.
    patch_offsets (`list[torch.Tensor]`, *optional*):
        list of tuples indicating the image index and start and end positions of patches for semantic segmentation.
    """

    loss: torch.FloatTensor | None = None
    class_queries_logits: torch.FloatTensor | None = None
    masks_queries_logits: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    patch_offsets: list[torch.Tensor] | None = None
|
|
1064
|
+
|
|
1065
|
+
|
|
1066
|
+
@auto_docstring
class EomtDinov3PreTrainedModel(PreTrainedModel):
    """
    Base class shared by the EoMT-DINOv3 models. It customizes weight initialization and provides the common
    interface for downloading and loading pretrained checkpoints.
    """

    config: EomtDinov3Config
    base_model_prefix = "eomt_dinov3"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = False
    _no_split_modules = ["EomtDinov3Layer"]
    _supports_sdpa = True
    _can_record_outputs = {
        "hidden_states": EomtDinov3Layer,
        "attentions": EomtDinov3Attention,
    }
    config_class = EomtDinov3Config

    @torch.no_grad()
    def _init_weights(self, module: nn.Module) -> None:
        # Generic initialization first; the branches below override EoMT-specific parameters and buffers.
        super()._init_weights(module)
        cfg = self.config
        init_std = cfg.initializer_range

        if isinstance(module, EomtDinov3LayerScale):
            # The scale parameter is optional, so only initialize it when the module owns one.
            if hasattr(module, "lambda1"):
                init.constant_(module.lambda1, cfg.layerscale_value)
            return

        if isinstance(module, EomtDinov3Embeddings):
            init.trunc_normal_(module.cls_token, mean=0.0, std=init_std)
            init.zeros_(module.register_tokens)
            return

        if isinstance(module, EomtDinov3Loss):
            # Cross-entropy class weights: 1 for every real class, `eos_coef` for the trailing null class.
            class_weights = torch.ones(module.num_labels + 1)
            class_weights[-1] = module.eos_coef
            init.copy_(module.empty_weight, class_weights)
            return

        if isinstance(module, EomtDinov3ForUniversalSegmentation):
            init.ones_(module.attn_mask_probs)
|
|
1102
|
+
|
|
1103
|
+
|
|
1104
|
+
class EomtDinov3LayerNorm2d(nn.LayerNorm):
    """Layer normalization over the channel dimension of a channels-first `(batch, channels, height, width)` map."""

    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        # Put channels last so `F.layer_norm` normalizes over them, then restore the channels-first layout.
        channels_last = hidden_state.permute(0, 2, 3, 1)
        normalized = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        return normalized.permute(0, 3, 1, 2)
|
|
1113
|
+
|
|
1114
|
+
|
|
1115
|
+
class EomtDinov3ScaleLayer(nn.Module):
    """
    One upsampling stage: a stride-2 transposed convolution (doubling height and width), the configured activation,
    a depthwise 3x3 convolution, and a channels-first layer norm.
    """

    def __init__(self, config: EomtDinov3Config):
        super().__init__()
        channels = config.hidden_size
        # kernel_size == stride == 2 without padding maps each spatial size n to 2 * n.
        self.conv1 = nn.ConvTranspose2d(channels, channels, kernel_size=2, stride=2)
        self.activation = ACT2FN[config.hidden_act]
        # Depthwise (groups == channels) 3x3 convolution; padding=1 preserves the spatial size.
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels, bias=False)

        self.layernorm2d = EomtDinov3LayerNorm2d(channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        upsampled = self.activation(self.conv1(hidden_states))
        return self.layernorm2d(self.conv2(upsampled))
|
|
1138
|
+
|
|
1139
|
+
|
|
1140
|
+
class EomtDinov3ScaleBlock(nn.Module):
    """Stack of `config.num_upscale_blocks` [`EomtDinov3ScaleLayer`] modules applied one after another."""

    def __init__(self, config: EomtDinov3Config):
        super().__init__()
        self.num_blocks = config.num_upscale_blocks
        scale_layers = [EomtDinov3ScaleLayer(config) for _ in range(self.num_blocks)]
        self.block = nn.ModuleList(scale_layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        output = hidden_states
        for scale_layer in self.block:
            output = scale_layer(output)
        return output
|
|
1150
|
+
|
|
1151
|
+
|
|
1152
|
+
class EomtDinov3MaskHead(nn.Module):
    """
    Three-layer MLP (hidden_size -> hidden_size at every layer) with the configured activation after the first two
    layers. In this file it is applied to the query tokens before they are combined with the upscaled patch features.
    """

    def __init__(self, config: EomtDinov3Config):
        super().__init__()

        width = config.hidden_size
        self.fc1 = nn.Linear(width, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.activation = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Non-linear hidden layers, then a purely linear output projection.
        for hidden_layer in (self.fc1, self.fc2):
            hidden_states = self.activation(hidden_layer(hidden_states))
        return self.fc3(hidden_states)
|
|
1167
|
+
|
|
1168
|
+
|
|
1169
|
+
@auto_docstring(
    custom_intro="""
    The EoMT-DINOv3 model with head on top for instance/semantic/panoptic segmentation.
    """,
)
class EomtDinov3ForUniversalSegmentation(EomtDinov3PreTrainedModel):
    main_input_name = "pixel_values"

    def __init__(self, config: EomtDinov3Config):
        super().__init__(config)
        self.config = config
        self.num_hidden_layers = config.num_hidden_layers
        self.embeddings = EomtDinov3Embeddings(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.query = nn.Embedding(config.num_queries, config.hidden_size)
        self.layers = nn.ModuleList([EomtDinov3Layer(config) for _ in range(config.num_hidden_layers)])

        self.upscale_block = EomtDinov3ScaleBlock(config)
        self.mask_head = EomtDinov3MaskHead(config)

        self.class_predictor = nn.Linear(config.hidden_size, config.num_labels + 1)

        self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
        self.weight_dict: dict[str, float] = {
            "loss_cross_entropy": config.class_weight,
            "loss_mask": config.mask_weight,
            "loss_dice": config.dice_weight,
        }

        self.criterion = EomtDinov3Loss(config=config, weight_dict=self.weight_dict)

        # Per-block probability of applying the masked attention in the last `num_blocks` layers.
        self.register_buffer("attn_mask_probs", torch.ones(config.num_blocks))

        self.num_prefix_tokens = 1 + config.num_register_tokens
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.embeddings.register_parameter("mask_token", None)

        self.rope_embeddings = EomtDinov3RotaryEmbedding(config)

        self.post_init()

    def get_loss_dict(
        self,
        masks_queries_logits: Tensor,
        class_queries_logits: Tensor,
        mask_labels: list[Tensor],
        class_labels: list[Tensor],
        auxiliary_predictions: list[dict[str, Tensor]] | None = None,
    ) -> dict[str, Tensor]:
        """
        Run the criterion and scale every loss term by its weight from `self.weight_dict`.

        A loss is scaled by a weight when the weight's key is a substring of the loss key, so auxiliary losses such
        as `loss_mask_0` are weighted like `loss_mask`.
        """
        loss_dict: dict[str, Tensor] = self.criterion(
            masks_queries_logits=masks_queries_logits,
            class_queries_logits=class_queries_logits,
            mask_labels=mask_labels,
            class_labels=class_labels,
            auxiliary_predictions=auxiliary_predictions,
        )

        # Build new tensors instead of multiplying in place, so no tensor that autograd may have saved is mutated.
        weighted_loss_dict: dict[str, Tensor] = {}
        for loss_key, loss in loss_dict.items():
            for key, weight in self.weight_dict.items():
                if key in loss_key:
                    loss = loss * weight
            weighted_loss_dict[loss_key] = loss

        return weighted_loss_dict

    def get_loss(self, loss_dict: dict[str, Tensor]) -> Tensor:
        """Sum all (already weighted) loss terms."""
        return sum(loss_dict.values())

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        pixel_values: Tensor,
        mask_labels: list[Tensor] | None = None,
        class_labels: list[Tensor] | None = None,
        patch_offsets: list[Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> EomtDinov3ForUniversalSegmentationOutput:
        r"""
        mask_labels (`list[torch.Tensor]`, *optional*):
            list of mask labels of shape `(num_labels, height, width)` to be fed to a model
        class_labels (`list[torch.LongTensor]`, *optional*):
            list of target class labels of shape `(num_labels,)` to be fed to a model. They identify the
            labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
        patch_offsets (`list[torch.Tensor]`, *optional*):
            list of tuples indicating the image index and start and end positions of patches for semantic segmentation.
        """
        masks_queries_logits_per_layer, class_queries_logits_per_layer = (), ()

        hidden_states = self.dropout(self.embeddings(pixel_values))
        position_embeddings = self.rope_embeddings(pixel_values.to(hidden_states.dtype))
        attention_mask = None

        # The learned queries are prepended before the last `num_blocks` layers, which use masked attention.
        first_query_layer = self.num_hidden_layers - self.config.num_blocks

        for idx, layer_module in enumerate(self.layers):
            if idx == first_query_layer:
                query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
                hidden_states = torch.cat((query, hidden_states), dim=1)

            if idx >= first_query_layer:
                mask_prob = self.attn_mask_probs[idx - first_query_layer]
                if self.training or mask_prob > 0:
                    norm_hidden_states = self.layernorm(hidden_states)
                    masks_queries_logits, class_queries_logits = self.predict(norm_hidden_states)

                    masks_queries_logits_per_layer += (masks_queries_logits,)
                    class_queries_logits_per_layer += (class_queries_logits,)

                    attention_mask = torch.ones(
                        hidden_states.shape[0],
                        hidden_states.shape[1],
                        hidden_states.shape[1],
                        device=hidden_states.device,
                        dtype=torch.bool,
                    )

                    interpolated_logits = F.interpolate(masks_queries_logits, size=self.grid_size, mode="bilinear")
                    interpolated_logits = interpolated_logits.view(
                        interpolated_logits.size(0), interpolated_logits.size(1), -1
                    )

                    num_query_tokens = self.config.num_queries
                    encoder_start_tokens = num_query_tokens + self.num_prefix_tokens

                    # Queries may only attend to encoder (patch) tokens where their current mask logit is positive.
                    attention_mask[:, :num_query_tokens, encoder_start_tokens:] = interpolated_logits > 0

                    # Randomly lift the mask for some queries: each query stays masked with probability `mask_prob`.
                    attention_mask = self._disable_attention_mask(
                        attention_mask,
                        prob=mask_prob,
                        num_query_tokens=num_query_tokens,
                        encoder_start_tokens=encoder_start_tokens,
                        device=attention_mask.device,
                    )

                    # Expand attention mask to 4d mask.
                    attention_mask = attention_mask[:, None, ...].expand(-1, self.config.num_attention_heads, -1, -1)
                    dtype_min = torch.finfo(hidden_states.dtype).min
                    attention_mask = attention_mask.to(hidden_states.dtype).masked_fill(~attention_mask, dtype_min)
                else:
                    # Masking is fully disabled for this block: run it with unrestricted attention rather than
                    # reusing the (stale) mask computed for a previous block.
                    attention_mask = None

            hidden_states = layer_module(
                hidden_states,
                attention_mask=attention_mask,
                position_embeddings=position_embeddings,
            )

        sequence_output = self.layernorm(hidden_states)

        masks_queries_logits, class_queries_logits = self.predict(sequence_output)
        masks_queries_logits_per_layer += (masks_queries_logits,)
        class_queries_logits_per_layer += (class_queries_logits,)

        loss = None
        if mask_labels is not None and class_labels is not None:
            loss = 0.0
            # Deep supervision: every intermediate prediction and the final one contribute to the loss.
            # Distinct loop names keep the final predictions (returned below) from being rebound.
            for layer_masks_logits, layer_class_logits in zip(
                masks_queries_logits_per_layer, class_queries_logits_per_layer
            ):
                loss_dict = self.get_loss_dict(
                    masks_queries_logits=layer_masks_logits,
                    class_queries_logits=layer_class_logits,
                    mask_labels=mask_labels,
                    class_labels=class_labels,
                    auxiliary_predictions=None,
                )
                loss += self.get_loss(loss_dict)

        return EomtDinov3ForUniversalSegmentationOutput(
            loss=loss,
            masks_queries_logits=masks_queries_logits,
            class_queries_logits=class_queries_logits,
            last_hidden_state=sequence_output,
            patch_offsets=patch_offsets,
        )

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def predict(self, logits: torch.Tensor):
        """
        Compute mask and class logits from a token sequence.

        Args:
            logits (`torch.Tensor`):
                Token sequence of shape `(batch_size, sequence_length, hidden_size)` laid out as
                `[query tokens | prefix tokens | patch tokens]`; the patch tokens must fill `self.grid_size`.

        Returns:
            `tuple(torch.Tensor, torch.Tensor)`: the mask logits of shape `(batch_size, num_queries, height, width)`
            (the patch grid after `self.upscale_block`) and the class logits of shape
            `(batch_size, num_queries, num_labels + 1)`.
        """
        query_tokens = logits[:, : self.config.num_queries, :]
        class_logits = self.class_predictor(query_tokens)

        # Skip the queries and the prefix tokens; the remaining tokens are the image patches.
        patch_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :]
        patch_tokens = patch_tokens.transpose(1, 2)

        # (batch_size, hidden_size, num_patches) -> (batch_size, hidden_size, grid_height, grid_width)
        patch_tokens = patch_tokens.reshape(patch_tokens.shape[0], -1, *self.grid_size)

        query_tokens = self.mask_head(query_tokens)
        patch_tokens = self.upscale_block(patch_tokens)

        # Per-query dot product with every upscaled patch feature.
        mask_logits = torch.einsum("bqc, bchw -> bqhw", query_tokens, patch_tokens)

        return mask_logits, class_logits

    @staticmethod
    def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device):
        """
        Randomly lift the query-to-encoder masking for some queries.

        Each query's row is reset to all-True (no masking) with probability `1 - prob`. With `prob >= 1` the mask is
        returned unchanged. `attn_mask` is modified in place and also returned.
        """
        if prob < 1:
            # Generate random queries to disable based on the probs
            random_queries = torch.rand(attn_mask.shape[0], num_query_tokens, device=device) > prob

            # Disable attention to the query tokens, considering the prefix tokens
            attn_mask[:, :num_query_tokens, encoder_start_tokens:][random_queries] = 1

        return attn_mask
|
|
1374
|
+
|
|
1375
|
+
|
|
1376
|
+
# Public API of this module.
__all__ = [
    "EomtDinov3PreTrainedModel",
    "EomtDinov3ForUniversalSegmentation",
]
|