transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -39,6 +39,10 @@ class DPRContextEncoderTokenizer(BertTokenizer):
 
     vocab_files_names = VOCAB_FILES_NAMES
 
+    def __init__(self, *args, do_lower_case=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.do_lower_case = do_lower_case
+
 
 class DPRQuestionEncoderTokenizer(BertTokenizer):
     r"""
@@ -52,6 +56,10 @@ class DPRQuestionEncoderTokenizer(BertTokenizer):
 
     vocab_files_names = VOCAB_FILES_NAMES
 
+    def __init__(self, *args, do_lower_case=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.do_lower_case = do_lower_case
+
 
 DPRSpanPrediction = collections.namedtuple(
     "DPRSpanPrediction", ["span_score", "relevance_score", "doc_id", "start_index", "end_index", "text"]
@@ -316,5 +324,9 @@ class DPRReaderTokenizer(CustomDPRReaderTokenizerMixin, BertTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
 
+    def __init__(self, *args, do_lower_case=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.do_lower_case = do_lower_case
+
 
 __all__ = ["DPRContextEncoderTokenizer", "DPRQuestionEncoderTokenizer", "DPRReaderOutput", "DPRReaderTokenizer"]
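The three hunks above share one pattern: each DPR tokenizer gains an explicit `do_lower_case` keyword that is split out of `**kwargs` and kept as an instance attribute. A minimal standalone sketch of that pattern (the class names here are illustrative stand-ins, not the real DPR classes):

```python
# Sketch only: mirrors the __init__ added in the DPR tokenizer hunks above.
class ParentTokenizer:                      # stand-in for the BertTokenizer base class
    def __init__(self, *args, **kwargs):
        self.init_kwargs = kwargs

class DemoTokenizer(ParentTokenizer):       # stand-in for DPRReaderTokenizer and friends
    def __init__(self, *args, do_lower_case=False, **kwargs):
        super().__init__(*args, **kwargs)   # everything else still reaches the parent
        self.do_lower_case = do_lower_case  # the flag is kept on the instance

print(DemoTokenizer(do_lower_case=True).do_lower_case)  # True
```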
@@ -102,7 +102,7 @@ class DPTConfig(PreTrainedConfig):
             Used only for the `hybrid` embedding type. The shape of the feature maps of the backbone.
         neck_ignore_stages (`list[int]`, *optional*, defaults to `[0, 1]`):
             Used only for the `hybrid` embedding type. The stages of the readout layers to ignore.
-        backbone_config (`Union[dict
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `BitConfig()`):
             The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to
             leverage the [`AutoBackbone`] API.
         backbone (`str`, *optional*):
@@ -225,8 +225,7 @@ class DPTImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-
-        return BatchFeature(data={"pixel_values": processed_images})
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
     def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
         """
@@ -228,8 +228,7 @@ class DPTImageProcessorFast(BeitImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-
-        return BatchFeature(data={"pixel_values": processed_images})
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
     def post_process_depth_estimation(
         self,
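Both DPT hunks now forward the caller's `return_tensors` into the `BatchFeature` constructor instead of dropping it, so the processor output is converted to the requested framework. A minimal standalone sketch of what `tensor_type` does (not the DPT code path itself):

```python
# Sketch only: BatchFeature converts its contents when a tensor_type is given.
import numpy as np
from transformers.feature_extraction_utils import BatchFeature

pixel_values = np.zeros((2, 3, 384, 384), dtype=np.float32)
batch = BatchFeature(data={"pixel_values": pixel_values}, tensor_type="pt")
print(type(batch["pixel_values"]))  # <class 'torch.Tensor'>
```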
@@ -33,7 +33,7 @@ class EdgeTamVisionConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.
 
     Args:
-        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `timm/repvit_m1.dist_in1k`):
             Configuration for the vision backbone. This is used to instantiate the backbone using
             `AutoModel.from_config`.
         backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`):
@@ -30,7 +30,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 
-from transformers.utils.generic import OutputRecorder
+from transformers.utils.generic import OutputRecorder
 
 from ... import initialization as init
 from ...activations import ACT2FN
@@ -39,6 +39,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import ModelOutput, auto_docstring
+from ...utils.generic import TransformersKwargs, check_model_inputs
 from ..auto import AutoModel
 from .configuration_edgetam import (
     EdgeTamConfig,
@@ -50,7 +51,7 @@ from .configuration_edgetam import (
 
 # fix this in modular
 if True:
-    from
+    from ..timm_wrapper.modeling_timm_wrapper import TimmWrapperModel
 
 
 class EdgeTamLayerNorm(nn.LayerNorm):
@@ -315,6 +316,8 @@ class EdgeTamPreTrainedModel(PreTrainedModel):
         if isinstance(module, EdgeTamModel):
             if module.no_memory_embedding is not None:
                 init.zeros_(module.no_memory_embedding)
+        elif hasattr(module, "positional_embedding"):
+            init.normal_(module.positional_embedding, std=module.scale)
 
 
 # copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
@@ -393,7 +396,7 @@ class EdgeTamVisionNeck(nn.Module):
         n = len(self.convs) - 1
         for i in range(n, -1, -1):
             lateral_features = hidden_states[i].permute(0, 3, 1, 2)
-            lateral_features = self.convs[n - i](lateral_features)
+            lateral_features = self.convs[n - i](lateral_features.to(self.convs[i].weight.dtype))
             if i not in self.fpn_top_down_levels or i == n:
                 prev_features = lateral_features
             else:
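The one-line change above casts the activations to the conv layer's parameter dtype before the call, which matters when the weights are loaded in half precision. A small standalone sketch of the same idiom (not the EdgeTam module itself):

```python
# Sketch only: align fp32 activations with a bf16 layer before calling it.
import torch
import torch.nn as nn

conv = nn.Conv2d(4, 4, kernel_size=1).to(torch.bfloat16)
x = torch.rand(1, 4, 8, 8)               # float32 activations
y = conv(x.to(conv.weight.dtype))        # cast avoids a dtype mismatch error
print(y.dtype)                           # torch.bfloat16
```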
@@ -19,8 +19,17 @@ from typing import Optional, Union
 import torch
 import torch.utils.checkpoint
 
-from
-from
+from ... import initialization as init
+from ...configuration_utils import PreTrainedConfig
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+    auto_docstring,
+)
+from ...utils.generic import TransformersKwargs, check_model_inputs
+from ..auto import CONFIG_MAPPING, AutoConfig
+from ..sam2.configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig
+from ..sam2.modeling_sam2 import (
     Sam2Attention,
     Sam2FeedForward,
     Sam2LayerNorm,
@@ -30,21 +39,11 @@ from transformers.models.sam2.modeling_sam2 import (
     Sam2VisionEncoderOutput,
     Sam2VisionModel,
 )
-from transformers.utils.generic import TransformersKwargs, check_model_inputs
-
-from ... import initialization as init
-from ...configuration_utils import PreTrainedConfig
-from ...modeling_utils import PreTrainedModel
-from ...processing_utils import Unpack
-from ...utils import (
-    auto_docstring,
-)
-from ..auto import CONFIG_MAPPING, AutoConfig
 
 
 # fix this in modular
 if True:
-    from
+    from ..timm_wrapper.modeling_timm_wrapper import TimmWrapperModel
 
 
 class EdgeTamVisionConfig(PreTrainedConfig):
@@ -58,7 +57,7 @@ class EdgeTamVisionConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.
 
     Args:
-        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `timm/repvit_m1.dist_in1k`):
             Configuration for the vision backbone. This is used to instantiate the backbone using
             `AutoModel.from_config`.
         backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`):
@@ -181,6 +180,8 @@ class EdgeTamPreTrainedModel(Sam2PreTrainedModel):
         if isinstance(module, EdgeTamModel):
             if module.no_memory_embedding is not None:
                 init.zeros_(module.no_memory_embedding)
+        elif hasattr(module, "positional_embedding"):
+            init.normal_(module.positional_embedding, std=module.scale)
 
 
 @auto_docstring(
@@ -152,24 +152,17 @@ class EdgeTamVideoVisionRotaryEmbedding(nn.Module):
 
     def __init__(self, config: EdgeTamVideoConfig, end_x: Optional[int] = None, end_y: Optional[int] = None):
         super().__init__()
-        dim = config.memory_attention_hidden_size // (
+        self.dim = config.memory_attention_hidden_size // (
             config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
         )
         # Ensure even dimension for proper axial splitting
-        if dim % 4 != 0:
+        if self.dim % 4 != 0:
             raise ValueError("Dimension must be divisible by 4 for axial RoPE")
-        end_x, end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
-
+        self.end_x, self.end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
+        self.memory_attention_rope_theta = config.memory_attention_rope_theta
 
-        # Generate 2D position indices for axial rotary embedding
-        flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
-        x_positions = flattened_indices % end_x
-        y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
-        freqs_x = torch.outer(x_positions, freqs).float()
-        freqs_y = torch.outer(y_positions, freqs).float()
-        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
-        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
         # directly register the cos and sin embeddings as we have a fixed feature shape
+        inv_freq = self.create_inv_freq()
         self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
         self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
 
@@ -178,6 +171,20 @@ class EdgeTamVideoVisionRotaryEmbedding(nn.Module):
         # As the feature map size is fixed, we can just return the pre-computed embeddings.
         return self.rope_embeddings_cos, self.rope_embeddings_sin
 
+    def create_inv_freq(self):
+        freqs = 1.0 / (
+            self.memory_attention_rope_theta ** (torch.arange(0, self.dim, 4)[: (self.dim // 4)].float() / self.dim)
+        )
+        # Generate 2D position indices for axial rotary embedding
+        flattened_indices = torch.arange(self.end_x * self.end_y, dtype=torch.long)
+        x_positions = flattened_indices % self.end_x
+        y_positions = torch.div(flattened_indices, self.end_x, rounding_mode="floor")
+        freqs_x = torch.outer(x_positions, freqs).float()
+        freqs_y = torch.outer(y_positions, freqs).float()
+        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+        return inv_freq
+
 
 def eager_attention_forward(
     module: nn.Module,
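Note: the hunk above moves the frequency-table construction into `create_inv_freq`, so the non-persistent `rope_embeddings_cos`/`rope_embeddings_sin` buffers can be rebuilt during weight initialization (see the `EdgeTamVideoPreTrainedModel` hunk further down). A minimal standalone sketch of the same axial construction; the values of `dim`, `end_x`, `end_y`, and `theta` are made-up stand-ins for what the model reads from its config:

import torch

# Hypothetical values; in the model they come from the EdgeTamVideo config.
dim, end_x, end_y, theta = 64, 16, 16, 10000.0

# One base frequency per group of four channels, as in the diffed code.
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: dim // 4].float() / dim))

# 2D grid positions flattened to (end_x * end_y,).
idx = torch.arange(end_x * end_y)
x_pos, y_pos = idx % end_x, torch.div(idx, end_x, rounding_mode="floor")

# Separate frequency tables per axis, concatenated and interleaved over channel pairs.
inv_freq = torch.cat([torch.outer(x_pos, freqs), torch.outer(y_pos, freqs)], dim=-1)
inv_freq = inv_freq.repeat_interleave(2, dim=-1)

cos, sin = inv_freq.cos(), inv_freq.sin()  # fixed buffers of shape (end_x * end_y, dim)
print(cos.shape)  # torch.Size([256, 64])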
@@ -769,6 +776,31 @@ class EdgeTamVideoFeedForward(nn.Module):
         return hidden_states
 
 
+class EdgeTamVideoPositionalEmbedding(nn.Module):
+    def __init__(self, config: EdgeTamVideoPromptEncoderConfig):
+        super().__init__()
+        self.scale = config.scale
+        positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
+        self.register_buffer("positional_embedding", positional_embedding)
+
+    def forward(self, input_coords, input_shape=None):
+        """Positionally encode points that are normalized to [0,1]."""
+        coordinates = input_coords.clone()
+
+        if input_shape is not None:
+            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
+            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
+        coordinates.to(torch.float32)
+
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coordinates = 2 * coordinates - 1
+        coordinates = coordinates.to(self.positional_embedding.dtype)
+        coordinates = coordinates @ self.positional_embedding
+        coordinates = 2 * np.pi * coordinates
+        # outputs d_1 x ... x d_n x channel shape
+        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
+
+
 @auto_docstring
 class EdgeTamVideoPreTrainedModel(PreTrainedModel):
     config_class = EdgeTamVideoConfig
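Note: the class added above is a SAM-style random Fourier point encoding: normalized (x, y) coordinates are projected through a fixed Gaussian matrix and expanded with sin/cos. A small self-contained sketch of the same idea; `hidden_size`, `scale`, and the toy points are assumptions, not values from the diff:

import math
import torch

hidden_size, scale = 8, 1.0  # hypothetical; the model reads these from its prompt-encoder config
gaussian = scale * torch.randn(2, hidden_size // 2)  # fixed random projection, kept as a buffer

points = torch.rand(1, 1, 3, 2)          # 3 points with (x, y) already normalized to [0, 1]
mapped = (2 * points - 1) @ gaussian      # map to [-1, 1], then project to hidden_size // 2
mapped = 2 * math.pi * mapped
encoding = torch.cat([mapped.sin(), mapped.cos()], dim=-1)
print(encoding.shape)  # torch.Size([1, 1, 3, 8])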
@@ -794,6 +826,16 @@ class EdgeTamVideoPreTrainedModel(PreTrainedModel):
         if isinstance(module, EdgeTamVideoMemoryFuserCXBlock):
             if module.scale is not None:
                 init.zeros_(module.scale)
+        elif isinstance(module, EdgeTamVideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())
+        elif isinstance(module, EdgeTamVideoPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)
+        if isinstance(module, EdgeTamVideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())
 
 
 class EdgeTamVideoInferenceCache:
@@ -959,7 +1001,7 @@ class EdgeTamVideoInferenceSession:
         device_inputs = {}
         for key, value in inputs.items():
             if isinstance(value, torch.Tensor):
-                device_inputs[key] = value.to(self.inference_device, non_blocking=
+                device_inputs[key] = value.to(self.inference_device, non_blocking=False)
             else:
                 device_inputs[key] = value
         self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs
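Note: the hunk above makes the host-to-device copy synchronous. As general PyTorch background (not stated in the diff itself), `non_blocking=True` only overlaps with compute when the source tensor lives in pinned memory, and asynchronous copies can surprise callers that consume the result immediately, so `non_blocking=False` is the conservative choice. A hypothetical illustration of the two call forms:

import torch

value = torch.randn(4, 4)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Synchronous copy (what the diff uses): safe regardless of whether `value` is pinned.
on_device = value.to(device, non_blocking=False)

# Asynchronous copy: only actually overlaps when the source is in pinned memory.
pinned = value.pin_memory() if device == "cuda" else value
maybe_async = pinned.to(device, non_blocking=True)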
@@ -1547,31 +1589,6 @@ class EdgeTamVideoSegmentationOutput(ModelOutput):
     frame_idx: Optional[int] = None
 
 
-class EdgeTamVideoPositionalEmbedding(nn.Module):
-    def __init__(self, config: EdgeTamVideoPromptEncoderConfig):
-        super().__init__()
-        self.scale = config.scale
-        positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
-        self.register_buffer("positional_embedding", positional_embedding)
-
-    def forward(self, input_coords, input_shape=None):
-        """Positionally encode points that are normalized to [0,1]."""
-        coordinates = input_coords.clone()
-
-        if input_shape is not None:
-            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
-            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
-        coordinates.to(torch.float32)
-
-        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
-        coordinates = 2 * coordinates - 1
-        coordinates = coordinates.to(self.positional_embedding.dtype)
-        coordinates = coordinates @ self.positional_embedding
-        coordinates = 2 * np.pi * coordinates
-        # outputs d_1 x ... x d_n x channel shape
-        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
-
-
 class EdgeTamVideoMaskEmbedding(nn.Module):
     def __init__(self, config: EdgeTamVideoPromptEncoderConfig):
         super().__init__()
@@ -1976,11 +1993,6 @@ class EdgeTamVideoModel(EdgeTamVideoPreTrainedModel):
     input_modalities = ("video", "text")
     _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamVideoTwoWayAttentionBlock, index=2)}
     _keys_to_ignore_on_load_unexpected = []
-    _tied_weights_keys = {
-        "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
-    }
-    # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
-    _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
 
     def __init__(self, config: EdgeTamVideoConfig):
         super().__init__(config)
@@ -2117,6 +2129,7 @@ class EdgeTamVideoModel(EdgeTamVideoPreTrainedModel):
         frame_idx: Optional[int] = None,
         frame: Optional[torch.Tensor] = None,
         reverse: bool = False,
+        **kwargs,
     ) -> EdgeTamVideoSegmentationOutput:
         r"""
         inference_session (`EdgeTamVideoInferenceSession`):
@@ -29,6 +29,7 @@ from transformers.models.sam2.modeling_sam2 import (
 )
 from transformers.utils.generic import OutputRecorder
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -375,24 +376,17 @@ class EdgeTamVideoVisionEncoderOutput(Sam2VideoVisionEncoderOutput):
 class EdgeTamVideoVisionRotaryEmbedding(Sam2VideoVisionRotaryEmbedding):
     def __init__(self, config: EdgeTamVideoConfig, end_x: Optional[int] = None, end_y: Optional[int] = None):
         nn.Module.__init__()
-        dim = config.memory_attention_hidden_size // (
+        self.dim = config.memory_attention_hidden_size // (
             config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
         )
         # Ensure even dimension for proper axial splitting
-        if dim % 4 != 0:
+        if self.dim % 4 != 0:
             raise ValueError("Dimension must be divisible by 4 for axial RoPE")
-        end_x, end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
-
-
-        # Generate 2D position indices for axial rotary embedding
-        flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
-        x_positions = flattened_indices % end_x
-        y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
-        freqs_x = torch.outer(x_positions, freqs).float()
-        freqs_y = torch.outer(y_positions, freqs).float()
-        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
-        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+        self.end_x, self.end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
+        self.memory_attention_rope_theta = config.memory_attention_rope_theta
+
         # directly register the cos and sin embeddings as we have a fixed feature shape
+        inv_freq = self.create_inv_freq()
         self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
         self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
 
@@ -662,7 +656,12 @@ class EdgeTamVideoFeedForward(Sam2VideoFeedForward):
 
 
 class EdgeTamVideoPreTrainedModel(Sam2VideoPreTrainedModel):
-
+    def _init_weights(self, module):
+        super()._init_weights()
+        if isinstance(module, EdgeTamVideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())
 
 
 class EdgeTamVideoInferenceSession(Sam2VideoInferenceSession):
@@ -1040,11 +1039,6 @@ class EdgeTamVideoSegmentationOutput(Sam2VideoSegmentationOutput):
 
 @auto_docstring
 class EdgeTamVideoModel(Sam2VideoModel):
-    _tied_weights_keys = {
-        "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
-    }
-    # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
-    _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
     _keys_to_ignore_on_load_unexpected = []
     _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamVideoTwoWayAttentionBlock, index=2)}
 
@@ -1256,6 +1250,7 @@ class EdgeTamVideoModel(Sam2VideoModel):
         frame_idx: Optional[int] = None,
         frame: Optional[torch.Tensor] = None,
         reverse: bool = False,
+        **kwargs,
     ) -> EdgeTamVideoSegmentationOutput:
         r"""
         inference_session (`EdgeTamVideoInferenceSession`):
@@ -153,9 +153,8 @@ class EfficientLoFTRImageProcessorFast(BaseImageProcessorFast):
         stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]
 
         # Return in same format as slow processor
-        image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs
 
-        return BatchFeature(data={"pixel_values":
+        return BatchFeature(data={"pixel_values": stacked_pairs}, tensor_type=return_tensors)
 
     def post_process_keypoint_matching(
         self,
@@ -33,7 +33,7 @@ from ...utils import (
     can_return_tuple,
     torch_int,
 )
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_efficientloftr import EfficientLoFTRConfig
 
 
@@ -103,7 +103,7 @@ class EfficientLoFTRRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     # Ignore copy
@@ -147,7 +147,7 @@ class EfficientLoFTRRotaryEmbedding(nn.Module):
         embed_height = (feats_height - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
         embed_width = (feats_width - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             emb = compute_embeddings(self.inv_freq, embed_height, embed_width, self.config.hidden_size)
             sin = emb.sin()
             cos = emb.cos()
@@ -684,9 +684,22 @@ class EfficientLoFTRPreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+        elif isinstance(module, EfficientLoFTRRotaryEmbedding):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)
 
     # Copied from transformers.models.superpoint.modeling_superpoint.SuperPointPreTrainedModel.extract_one_channel_pixel_values with SuperPoint->EfficientLoFTR
     def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
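Note: a recurring pattern in these hunks is that `_init_weights` now also restores buffers (RoPE tables, BatchNorm running statistics) rather than only weights and biases, presumably so that models materialized without pretrained weights still end up with valid buffer values. A minimal sketch of the BatchNorm part of that pattern; the helper name `reset_batchnorm_buffers` is hypothetical and plain `torch.nn.init` is used in place of transformers' internal `initialization` helpers:

import torch
from torch import nn

def reset_batchnorm_buffers(module: nn.Module) -> None:
    # Mirrors the diff's pattern: detect running statistics via attribute presence
    # and restore them to the BatchNorm identity defaults.
    if getattr(module, "running_mean", None) is not None:
        nn.init.zeros_(module.running_mean)
        nn.init.ones_(module.running_var)
        nn.init.zeros_(module.num_batches_tracked)

bn = nn.BatchNorm2d(8)
bn.running_mean.fill_(3.0)           # pretend the buffer was left in a stale state
reset_batchnorm_buffers(bn)
print(bn.running_mean.sum().item())  # 0.0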
@@ -66,7 +66,7 @@ class EfficientNetImageProcessor(BaseImageProcessor):
             `do_resize` in `preprocess`.
         size (`dict[str, int]` *optional*, defaults to `{"height": 346, "width": 346}`):
             Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
-        resample (`PILImageResampling` filter, *optional*, defaults to
+        resample (`PILImageResampling` filter, *optional*, defaults to `Resampling.BICUBIC`):
             Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
         do_center_crop (`bool`, *optional*, defaults to `False`):
             Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
@@ -102,7 +102,7 @@ class EfficientNetImageProcessor(BaseImageProcessor):
         self,
         do_resize: bool = True,
         size: Optional[dict[str, int]] = None,
-        resample: PILImageResampling =
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
         do_center_crop: bool = False,
         crop_size: Optional[dict[str, int]] = None,
         rescale_factor: Union[int, float] = 1 / 255,
@@ -133,12 +133,11 @@ class EfficientNetImageProcessor(BaseImageProcessor):
         self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
         self.include_top = include_top
 
-    # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.NEAREST
     def resize(
         self,
         image: np.ndarray,
         size: dict[str, int],
-        resample: PILImageResampling = PILImageResampling.
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
@@ -151,8 +150,8 @@ class EfficientNetImageProcessor(BaseImageProcessor):
                 Image to resize.
             size (`dict[str, int]`):
                 Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
-            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.
-                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
             data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format for the output image. If unset, the channel dimension format of the input
                 image is used. Can be one of:
@@ -33,7 +33,7 @@ from .image_processing_efficientnet import EfficientNetImageProcessorKwargs
 
 @auto_docstring
 class EfficientNetImageProcessorFast(BaseImageProcessorFast):
-    resample = PILImageResampling.
+    resample = PILImageResampling.BICUBIC
     image_mean = IMAGENET_STANDARD_MEAN
     image_std = IMAGENET_STANDARD_STD
     size = {"height": 346, "width": 346}
@@ -178,7 +178,6 @@ class EfficientNetImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
@@ -435,7 +435,7 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
     base_model_prefix = "efficientnet"
     main_input_name = "pixel_values"
     input_modalities = ("image",)
-    _no_split_modules = []
+    _no_split_modules = ["EfficientNetBlock"]
 
     @torch.no_grad()
     def _init_weights(self, module: nn.Module):
@@ -444,6 +444,10 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
 
 
 @auto_docstring
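Note: `_no_split_modules = ["EfficientNetBlock"]` in the `EfficientNetPreTrainedModel` hunks above tells device-map sharding not to split an `EfficientNetBlock` across devices. A hedged usage sketch; the checkpoint id, an installed accelerate, and the presence of multiple devices are all assumptions here, not part of the diff:

from transformers import AutoModelForImageClassification

# With device_map="auto", whole EfficientNetBlock modules are kept on a single
# device instead of having their submodules spread across devices.
model = AutoModelForImageClassification.from_pretrained(
    "google/efficientnet-b0",  # assumed checkpoint for illustration
    device_map="auto",
)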
@@ -471,6 +475,7 @@ class EfficientNetModel(EfficientNetPreTrainedModel):
         pixel_values: Optional[torch.FloatTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -529,6 +534,7 @@ class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -22,6 +22,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ... import initialization as init
 from ...activations import ACT2FN, get_activation
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -532,6 +533,12 @@ class ElectraPreTrainedModel(PreTrainedModel):
         "cross_attentions": ElectraCrossAttention,
     }
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, ElectraEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
+
 
 @dataclass
 @auto_docstring(
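Note: the override added above restores the `position_ids` and `token_type_ids` buffers of `ElectraEmbeddings` whenever weights are (re)initialized. A minimal sketch of the default buffer contents it recreates; the maximum position length of 8 is a made-up value standing in for the config setting:

import torch

max_position_embeddings = 8  # hypothetical; the model reads this from its config

# Default buffer contents that the new _init_weights restores:
position_ids = torch.arange(max_position_embeddings).expand((1, -1))  # [[0, 1, ..., 7]]
token_type_ids = torch.zeros_like(position_ids)                       # all segment-A tokens

print(position_ids.tolist())  # [[0, 1, 2, 3, 4, 5, 6, 7]]
print(int(token_type_ids.sum()))  # 0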