transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/dia/modeling_dia.py

@@ -25,9 +25,10 @@ from typing import Optional, Union
 import torch
 from torch import nn
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_bidirectional_mask, create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -41,6 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ...utils.generic import maybe_autocast
 from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
 from .generation_dia import DiaGenerationMixin
 
@@ -60,6 +62,12 @@ class DiaPreTrainedModel(PreTrainedModel):
     main_input_name = "input_ids"
     _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, DiaMultiChannelEmbedding):
+            offsets = torch.arange(self.config.num_channels, dtype=torch.long) * self.config.vocab_size
+            init.copy_(module.offsets, offsets)
+
 
 class DiaMultiChannelEmbedding(nn.Module):
     """In order to efficiently compute the audio embedding from the 9 different channels,
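Note: the new `_init_weights` hook seeds the multi-channel embedding's `offsets` buffer deterministically instead of leaving it to the generic initializer. A minimal sketch of the index layout it encodes, with an illustrative `vocab_size` (values assumed, not taken from the diff):

    import torch

    # Each of the 9 audio channels addresses a disjoint slice of one shared
    # embedding table: channel c, token t -> flat index c * vocab_size + t.
    num_channels, vocab_size = 9, 1024  # illustrative values
    offsets = torch.arange(num_channels, dtype=torch.long) * vocab_size
    print(offsets)  # tensor([0, 1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192])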
@@ -145,7 +153,7 @@ class DiaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
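Note: `original_inv_freq` moves from a plain tensor attribute to a non-persistent buffer. A short sketch of the behavioral difference under standard PyTorch buffer semantics (the rationale is assumed, not stated in the diff):

    import torch
    from torch import nn

    class RopeBufferSketch(nn.Module):
        def __init__(self, inv_freq: torch.Tensor):
            super().__init__()
            # A registered buffer follows .to()/.cuda() together with the
            # module; persistent=False additionally keeps it out of
            # state_dict, so saved checkpoints are unchanged.
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    m = RopeBufferSketch(torch.arange(4, dtype=torch.float32))
    print("original_inv_freq" in m.state_dict())  # False: non-persistent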
@@ -184,7 +192,7 @@ class DiaRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
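Note: `torch.autocast(...)` is replaced by a `maybe_autocast` helper imported from `transformers.utils.generic`. A hypothetical sketch of what such a helper could look like (the actual implementation may differ; the device-type check is an assumption):

    import contextlib
    import torch

    def maybe_autocast(device_type: str, enabled: bool = True):
        # Use torch.autocast where the device type supports it; otherwise
        # fall back to a no-op context instead of raising.
        if device_type in ("cuda", "cpu"):  # assumed autocast-capable types
            return torch.autocast(device_type=device_type, enabled=enabled)
        return contextlib.nullcontext()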
@@ -266,6 +274,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class DiaSelfAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
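Note: `use_kernelized_func(apply_rotary_pos_emb)` comes from the expanded hub-kernels integration (see transformers/integrations/hub_kernels.py in the file list above). A hypothetical sketch of the decorator pattern, with `_load_hub_kernel` stubbed out as an assumed lookup helper:

    from typing import Callable, Optional

    def _load_hub_kernel(func: Callable) -> Optional[Callable]:
        # Assumed lookup: return an optimized kernel replacing `func` if one
        # is installed, else None. Stubbed here for illustration.
        return None

    def use_kernelized_func(func: Callable):
        # Bind a possibly kernel-accelerated variant of `func` to the class
        # so its forward pass can call the fast path when available.
        def wrap(cls):
            cls._kernelized_func = staticmethod(_load_hub_kernel(func) or func)
            return cls
        return wrap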
@@ -450,6 +459,8 @@ class DiaEncoder(DiaPreTrainedModel):
         self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
         self.rotary_emb = DiaRotaryEmbedding(config=config)
 
+        self.post_init()
+
     @auto_docstring
     @can_return_tuple
     def forward(
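Note: `post_init()` is `PreTrainedModel`'s end-of-constructor hook (weight initialization plus backward-compatibility fixups). Calling it here means an encoder or decoder built standalone, not only through the full model, is fully initialized. Usage sketch, assuming default config values:

    from transformers.models.dia.configuration_dia import DiaEncoderConfig
    from transformers.models.dia.modeling_dia import DiaEncoder

    encoder = DiaEncoder(DiaEncoderConfig())  # _init_weights now runs via post_init()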
@@ -523,7 +534,6 @@ class DiaDecoderLayer(GradientCheckpointingLayer):
         encoder_attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[EncoderDecoderCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         self_attn_cache = past_key_values
@@ -577,6 +587,8 @@ class DiaDecoder(DiaPreTrainedModel):
         self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
         self.rotary_emb = DiaRotaryEmbedding(config=config)
 
+        self.post_init()
+
     @auto_docstring
     @can_return_tuple
     def forward(
transformers/models/dia/modular_dia.py

@@ -20,6 +20,7 @@ from typing import Optional, Union
 import torch
 from torch import nn
 
+from ... import initialization as init
 from ...cache_utils import DynamicCache, EncoderDecoderCache
 from ...masking_utils import create_bidirectional_mask, create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -59,6 +60,12 @@ class DiaPreTrainedModel(PreTrainedModel):
     main_input_name = "input_ids"
     _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, DiaMultiChannelEmbedding):
+            offsets = torch.arange(self.config.num_channels, dtype=torch.long) * self.config.vocab_size
+            init.copy_(module.offsets, offsets)
+
 
 class DiaMultiChannelEmbedding(nn.Module):
     """In order to efficiently compute the audio embedding from the 9 different channels,
@@ -241,6 +248,8 @@ class DiaEncoder(DiaPreTrainedModel):
         self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
         self.rotary_emb = DiaRotaryEmbedding(config=config)
 
+        self.post_init()
+
     @auto_docstring
     @can_return_tuple
     def forward(
@@ -314,7 +323,6 @@ class DiaDecoderLayer(GradientCheckpointingLayer):
         encoder_attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[EncoderDecoderCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         self_attn_cache = past_key_values
@@ -368,6 +376,8 @@ class DiaDecoder(DiaPreTrainedModel):
         self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
         self.rotary_emb = DiaRotaryEmbedding(config=config)
 
+        self.post_init()
+
     @auto_docstring
     @can_return_tuple
     def forward(
@@ -74,7 +74,7 @@ class DiaProcessor(ProcessorMixin):
|
|
|
74
74
|
tokenizer (`DiaTokenizer`):
|
|
75
75
|
An instance of [`DiaTokenizer`]. The tokenizer is a required input.
|
|
76
76
|
audio_tokenizer (`DacModel`):
|
|
77
|
-
An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is
|
|
77
|
+
An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is a required input.
|
|
78
78
|
"""
|
|
79
79
|
|
|
80
80
|
audio_tokenizer_class = "DacModel"
|
|
@@ -46,7 +46,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_diffllama import DiffLlamaConfig

@@ -86,7 +86,7 @@ class DiffLlamaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(

@@ -125,7 +125,7 @@ class DiffLlamaRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
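`maybe_autocast` is new in this release; judging purely from these call sites, it is a drop-in for `torch.autocast` that tolerates device types where autocast is unavailable. A plausible sketch of such a wrapper, stated as an assumption (the real helper lives in `transformers/utils/generic.py` and may differ):

import contextlib
import torch

def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
    """Hypothetical reimplementation for illustration only: return
    torch.autocast when the device supports it, else a no-op context."""
    try:
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:
        # e.g. device types without autocast support
        return contextlib.nullcontext()

# Usage mirrors the RoPE forward above: disable autocast so the
# inv_freq @ position_ids product stays in float32 even under AMP.
with maybe_autocast(device_type="cpu", enabled=False):
    x = torch.ones(2, 2) @ torch.ones(2, 2)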
@@ -361,8 +361,8 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
             else torch.get_autocast_gpu_dtype()
         )
         # Handle the case where the model is quantized
-        elif hasattr(self.config, "_pre_quantization_dtype"):
-            target_dtype = self.config._pre_quantization_dtype
+        elif hasattr(self.config, "quantization_config"):
+            target_dtype = self.config.dtype
         else:
             target_dtype = self.q_proj.weight.dtype
 
@@ -236,8 +236,8 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
             else torch.get_autocast_gpu_dtype()
         )
         # Handle the case where the model is quantized
-        elif hasattr(self.config, "_pre_quantization_dtype"):
-            target_dtype = self.config._pre_quantization_dtype
+        elif hasattr(self.config, "quantization_config"):
+            target_dtype = self.config.dtype
         else:
             target_dtype = self.q_proj.weight.dtype
 
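The branch rewritten in the two hunks above picks the dtype that query/key/value states must be cast to before calling flash attention, since FA2 rejects float32 inputs. A condensed sketch of the selection logic, under the assumption that the surrounding code matches the usual transformers pattern (the function name is mine):

import torch

def select_flash_attn_dtype(config, weight: torch.Tensor) -> torch.dtype:
    """Condensed version of the target_dtype selection shown above."""
    if torch.is_autocast_enabled():
        # Under AMP, cast to the active autocast dtype.
        return torch.get_autocast_gpu_dtype()
    if hasattr(config, "quantization_config"):
        # Quantized weights may have integer storage dtypes; use the
        # configured compute dtype instead of the weight's storage dtype.
        return config.dtype
    return weight.dtype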
@@ -596,6 +596,7 @@ class DinatModel(DinatPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, DinatModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

@@ -668,6 +669,7 @@ class DinatForImageClassification(DinatPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, DinatImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@@ -740,6 +742,7 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BackboneOutput:
         r"""
         Examples:
@@ -214,7 +214,7 @@ class DINOv3ConvNextModel(DINOv3ConvNextPreTrainedModel):
     @can_return_tuple
     @auto_docstring
     def forward(
-        self, pixel_values: torch.FloatTensor, output_hidden_states: Optional[bool] = None
+        self, pixel_values: torch.FloatTensor, output_hidden_states: Optional[bool] = None, **kwargs
     ) -> BaseModelOutputWithPoolingAndNoAttention:
         hidden_states = pixel_values
 

@@ -88,7 +88,6 @@ class DINOv3ViTImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
@@ -36,7 +36,7 @@ from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.backbone_utils import BackboneMixin
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dinov3_vit import DINOv3ViTConfig

@@ -156,7 +156,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
         device = pixel_values.device
         device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
 
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             # Although we could precompute static patch_coords from image_size and patch_size in the config,
             # the model was trained with random_scale, so it can process images of varying sizes.
             # Therefore, it's better to compute patch_coords dynamically (with lru_cache).

@@ -466,6 +466,9 @@ class DINOv3ViTPreTrainedModel(PreTrainedModel):
             init.zeros_(module.mask_token)
         elif isinstance(module, DINOv3ViTLayerScale):
             init.constant_(module.lambda1, self.config.layerscale_value)
+        elif isinstance(module, DINOv3ViTRopePositionEmbedding):
+            inv_freq = 1 / module.base ** torch.arange(0, 1, 4 / module.head_dim, dtype=torch.float32)
+            init.copy_(module.inv_freq, inv_freq)
 
 
     @auto_docstring
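The `inv_freq` expression added to `_init_weights` is the standard RoPE frequency schedule written compactly: `torch.arange(0, 1, 4 / head_dim)` enumerates the exponents `4i / head_dim`, giving `head_dim / 4` frequencies rather than the usual `head_dim / 2` (consistent with splitting them across two spatial axes, though the diff itself doesn't spell that out). A quick standalone check that the compact form matches the indexed form, with illustrative constants:

import torch

base, head_dim = 100.0, 64

# Compact form from the hunk above.
inv_freq = 1 / base ** torch.arange(0, 1, 4 / head_dim, dtype=torch.float32)

# Conventional indexed form: exponents 4*i / head_dim for i = 0 .. head_dim//4 - 1.
i = torch.arange(head_dim // 4, dtype=torch.float32)
expected = 1 / base ** (4 * i / head_dim)

assert torch.allclose(inv_freq, expected)
print(inv_freq.shape)  # torch.Size([16])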
@@ -40,7 +40,7 @@ from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.backbone_utils import BackboneMixin
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dinov3_vit import DINOv3ViTConfig

@@ -163,7 +163,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
         device = pixel_values.device
         device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
 
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             # Although we could precompute static patch_coords from image_size and patch_size in the config,
             # the model was trained with random_scale, so it can process images of varying sizes.
             # Therefore, it's better to compute patch_coords dynamically (with lru_cache).

@@ -361,6 +361,9 @@ class DINOv3ViTPreTrainedModel(Dinov2PreTrainedModel):
             init.zeros_(module.mask_token)
         elif isinstance(module, DINOv3ViTLayerScale):
             init.constant_(module.lambda1, self.config.layerscale_value)
+        elif isinstance(module, DINOv3ViTRopePositionEmbedding):
+            inv_freq = 1 / module.base ** torch.arange(0, 1, 4 / module.head_dim, dtype=torch.float32)
+            init.copy_(module.inv_freq, inv_freq)
 
 
     @auto_docstring
@@ -305,15 +305,17 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module: nn.Module):
         """Initialize the weights."""
         super()._init_weights(module)
-        if isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds:
-            init.copy_(
-                module.position_embeddings.weight,
-                create_sinusoidal_embeddings(
-                    self.config.max_position_embeddings,
-                    self.config.dim,
-                    module.position_embeddings.weight,
-                ),
-            )
+        if isinstance(module, Embeddings):
+            if self.config.sinusoidal_pos_embds:
+                init.copy_(
+                    module.position_embeddings.weight,
+                    create_sinusoidal_embeddings(
+                        self.config.max_position_embeddings,
+                        self.config.dim,
+                        torch.empty_like(module.position_embeddings.weight),
+                    ),
+                )
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
     @auto_docstring
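`create_sinusoidal_embeddings` fills the position table with the fixed sin/cos pattern from "Attention Is All You Need"; the change above routes it through `init.copy_` and a `torch.empty_like` scratch tensor so the module's own weight is not mutated in place during init. A minimal sketch of the sinusoidal fill itself (my paraphrase of the classic formula, not the exact helper):

import numpy as np
import torch

def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    """PE[pos, 2i] = sin(pos / 10000^(2i/dim)), PE[pos, 2i+1] = cos(...)."""
    position = np.arange(n_pos)[:, None]
    div = np.power(10000, 2 * (np.arange(dim)[None, :] // 2) / dim)
    table = position / div
    table[:, 0::2] = np.sin(table[:, 0::2])
    table[:, 1::2] = np.cos(table[:, 1::2])
    return torch.tensor(table, dtype=torch.float32)

weights = sinusoidal_table(n_pos=512, dim=768)
print(weights.shape)  # torch.Size([512, 768])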
@@ -23,6 +23,19 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
 class DistilBertTokenizer(BertTokenizer):
     model_input_names = ["input_ids", "attention_mask"]
 
+    def __init__(self, *args, do_lower_case: bool = True, **kwargs):
+        """
+        Construct a DistilBERT tokenizer (backed by HuggingFace's tokenizers library). Based on WordPiece.
+
+        This tokenizer inherits from [`BertTokenizer`] which contains most of the main methods. Users should refer to
+        this superclass for more information regarding those methods.
+
+        Args:
+            do_lower_case (`bool`, *optional*, defaults to `True`):
+                Whether or not to lowercase the input when tokenizing.
+        """
+        super().__init__(*args, do_lower_case=do_lower_case, **kwargs)
+
 
 # DistilBertTokenizerFast is an alias for DistilBertTokenizer (since BertTokenizer is already a fast tokenizer)
 DistilBertTokenizerFast = DistilBertTokenizer
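As a usage note, the override above only changes the default of `do_lower_case`; everything else still comes from `BertTokenizer`, and the fast alias points at the same class. Loading works as usual (the checkpoint name is the standard public one, shown for illustration):

from transformers import DistilBertTokenizer

tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
print(tok("Hello WORLD")["input_ids"])  # input is lowercased before WordPiece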
@@ -42,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import AttentionInterface, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_doge import DogeConfig

@@ -88,7 +88,7 @@ class DogeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(

@@ -127,7 +127,7 @@ class DogeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -297,7 +297,6 @@ class DogeAttention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]

@@ -321,7 +321,6 @@ class DogeAttention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]

@@ -231,7 +231,6 @@ class DonutImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
@@ -381,18 +381,7 @@ class DonutSwinSelfAttention(nn.Module):
             torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
         )
 
-
-        coords_h = torch.arange(self.window_size[0])
-        coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
-        coords_flatten = torch.flatten(coords, 1)
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
-        relative_coords[:, :, 0] += self.window_size[0] - 1
-        relative_coords[:, :, 1] += self.window_size[1] - 1
-        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
-        relative_position_index = relative_coords.sum(-1)
-        self.register_buffer("relative_position_index", relative_position_index)
+        self.register_buffer("relative_position_index", self.create_relative_position_index())
 
         self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
         self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

@@ -451,6 +440,20 @@ class DonutSwinSelfAttention(nn.Module):
 
         return outputs
 
+    def create_relative_position_index(self):
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
+        coords_flatten = torch.flatten(coords, 1)
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+        relative_coords[:, :, 0] += self.window_size[0] - 1
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)
+        return relative_position_index
+
 
 # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
 class DonutSwinSelfOutput(nn.Module):
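The code moved into `create_relative_position_index` computes, for every pair of positions inside an attention window, a single index into the learned relative-position-bias table. A tiny standalone check with a 2x2 window; the window sizes are spelled out here and `torch.meshgrid` replaces the model's `meshgrid` helper, so this is a self-contained mirror rather than the class itself:

import torch

def relative_position_index(window_h: int, window_w: int) -> torch.Tensor:
    # Same arithmetic as DonutSwinSelfAttention.create_relative_position_index.
    coords_h = torch.arange(window_h)
    coords_w = torch.arange(window_w)
    coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
    coords_flatten = torch.flatten(coords, 1)  # (2, window_h * window_w)
    # Pairwise (dy, dx) between every pair of window positions.
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    # Shift into the non-negative range, then flatten (dy, dx) to one index.
    relative_coords[:, :, 0] += window_h - 1
    relative_coords[:, :, 1] += window_w - 1
    relative_coords[:, :, 0] *= 2 * window_w - 1
    return relative_coords.sum(-1)

idx = relative_position_index(2, 2)
print(idx.shape)  # torch.Size([4, 4]); values index a (2*2-1)*(2*2-1) = 9-entry bias table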
@@ -801,6 +804,7 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
             init.zeros_(module.position_embeddings)
         elif isinstance(module, DonutSwinSelfAttention):
             init.zeros_(module.relative_position_bias_table)
+            init.copy_(module.relative_position_index, module.create_relative_position_index())
 
 
 @auto_docstring

@@ -837,6 +841,7 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, DonutSwinModelOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):

@@ -923,6 +928,7 @@ class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, DonutSwinImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -29,7 +29,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer

@@ -37,8 +42,8 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dots1 import Dots1Config
@@ -80,7 +85,7 @@ class Dots1RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(

@@ -119,7 +124,7 @@ class Dots1RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -201,6 +206,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class Dots1Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 

@@ -227,7 +233,6 @@ class Dots1Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
         self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

@@ -308,6 +313,7 @@ class Dots1TopkRouter(nn.Module):
         return router_logits
 
 
+@use_experts_implementation
 class Dots1NaiveMoe(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
@@ -315,7 +321,7 @@ class Dots1NaiveMoe(nn.Module):
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_dim = config.hidden_size
-        self.intermediate_dim = config.intermediate_size
+        self.intermediate_dim = config.moe_intermediate_size
         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
         self.act_fn = ACT2FN[config.hidden_act]
@@ -369,9 +375,11 @@ class Dots1MoE(nn.Module):
 
     def route_tokens_to_experts(self, router_logits):
         router_logits = router_logits.sigmoid()  # main diff with deepseekv3
-        router_logits = router_logits + self.gate.e_score_correction_bias
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
         group_scores = (
-            router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
         )
         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
         group_mask = torch.zeros_like(group_scores)

@@ -381,7 +389,7 @@ class Dots1MoE(nn.Module):
             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
         topk_weights = router_logits.gather(1, topk_indices)
         if self.norm_topk_prob:
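Read together, the two hunks above separate the bias-corrected scores (`router_logits_for_choice`), which should only decide which experts are active, from the raw sigmoid scores, which the gating weights are gathered from. A self-contained sketch of this grouped top-k routing; group counts, shapes, and the function name are illustrative:

import torch

def route(router_logits, bias, n_group, topk_group, top_k):
    """Grouped top-k routing: biased scores choose experts, raw scores weight them."""
    scores = router_logits.sigmoid()
    scores_for_choice = scores + bias  # bias steers selection only
    n_experts = scores.shape[-1]

    # Score each group by the sum of its two best experts, keep topk_group groups.
    group_scores = scores_for_choice.view(-1, n_group, n_experts // n_group).topk(2, dim=-1)[0].sum(dim=-1)
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = group_mask[..., None].expand(-1, n_group, n_experts // n_group).reshape(-1, n_experts)

    # Select experts inside the surviving groups, but weight by the *unbiased* scores.
    masked = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
    topk_idx = torch.topk(masked, k=top_k, dim=-1, sorted=False)[1]
    return topk_idx, scores.gather(1, topk_idx)

logits = torch.randn(4, 16)  # 4 tokens, 16 experts in 4 groups
idx, w = route(logits, torch.zeros(16), n_group=4, topk_group=2, top_k=4)
print(idx.shape, w.shape)  # torch.Size([4, 4]) torch.Size([4, 4])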
@@ -461,18 +469,22 @@ class Dots1PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": Dots1DecoderLayer,
         "attentions": Dots1Attention,
     }
+    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
 
     @torch.no_grad()
     def _init_weights(self, module):
         super()._init_weights(module)
         if isinstance(module, Dots1TopkRouter):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            init.zeros_(module.e_score_correction_bias)
         elif isinstance(module, Dots1NaiveMoe):
             init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
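The new `_keep_in_fp32_modules_strict` entry pins `e_score_correction_bias` to float32 even when the rest of the model is cast to half precision. A plausible motivation, shown as my own illustration rather than anything stated in the diff: routing biases are tiny corrections near the scale of the sigmoid scores, and low-precision dtypes quantize such values coarsely enough to change expert selection.

import torch

# Values that differ at the third decimal place collapse together in bfloat16,
# whose spacing near 1.0 is 2**-8 (about 0.0039).
bias = torch.tensor([1e-3, 1.0005, 1.0010])
print(bias.to(torch.bfloat16).to(torch.float32))  # both values near 1.0 become 1.0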
@@ -61,9 +61,11 @@ class Dots1TopkRouter(DeepseekV3TopkRouter):
 class Dots1MoE(DeepseekV3MoE):
     def route_tokens_to_experts(self, router_logits):
         router_logits = router_logits.sigmoid()  # main diff with deepseekv3
-        router_logits = router_logits + self.gate.e_score_correction_bias
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
         group_scores = (
-            router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
         )
         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
         group_mask = torch.zeros_like(group_scores)

@@ -73,7 +75,7 @@ class Dots1MoE(DeepseekV3MoE):
             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
         topk_weights = router_logits.gather(1, topk_indices)
         if self.norm_topk_prob:
@@ -129,6 +129,7 @@ class DPREncoder(DPRPreTrainedModel):
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = False,
+        **kwargs,
     ) -> Union[BaseModelOutputWithPooling, tuple[Tensor, ...]]:
         outputs = self.bert_model(
             input_ids=input_ids,

@@ -181,6 +182,7 @@ class DPRSpanPredictor(DPRPreTrainedModel):
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = False,
+        **kwargs,
     ) -> Union[DPRReaderOutput, tuple[Tensor, ...]]:
         # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
         n_passages, sequence_length = input_ids.size() if input_ids is not None else inputs_embeds.size()[:2]

@@ -282,6 +284,7 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[DPRContextEncoderOutput, tuple[Tensor, ...]]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

@@ -387,6 +390,7 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[DPRQuestionEncoderOutput, tuple[Tensor, ...]]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

@@ -492,6 +496,7 @@ class DPRReader(DPRPretrainedReader):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[DPRReaderOutput, tuple[Tensor, ...]]:
         r"""
         input_ids (`tuple[torch.LongTensor]` of shapes `(n_passages, sequence_length)`):
|