transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0

transformers/models/dinov3_vit/modeling_dinov3_vit.py
@@ -36,7 +36,7 @@ from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.backbone_utils import BackboneMixin
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dinov3_vit import DINOv3ViTConfig


@@ -156,7 +156,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
         device = pixel_values.device
         device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"

-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             # Although we could precompute static patch_coords from image_size and patch_size in the config,
             # the model was trained with random_scale, so it can process images of varying sizes.
             # Therefore, it's better to compute patch_coords dynamically (with lru_cache).
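
In this hunk (and in the many RoPE modules further down), the `with` block around the embedding math now uses `maybe_autocast(device_type=device_type, enabled=False)` so the rotary/position embeddings stay in float32 even under autocast. A minimal sketch of what such a helper could look like, assuming it only guards `torch.autocast` and falls back to a no-op context; the actual helper is defined in transformers/utils/generic.py and may differ:

# Illustrative sketch only, not the library implementation.
from contextlib import nullcontext
import torch

def maybe_autocast(device_type="cpu", enabled=True, **kwargs):
    try:
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:
        return nullcontext()  # device type without autocast support: run unchanged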

transformers/models/dinov3_vit/modular_dinov3_vit.py
@@ -40,7 +40,7 @@ from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.backbone_utils import BackboneMixin
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dinov3_vit import DINOv3ViTConfig


@@ -163,7 +163,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
         device = pixel_values.device
         device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"

-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             # Although we could precompute static patch_coords from image_size and patch_size in the config,
             # the model was trained with random_scale, so it can process images of varying sizes.
             # Therefore, it's better to compute patch_coords dynamically (with lru_cache).

transformers/models/distilbert/tokenization_distilbert.py
@@ -23,6 +23,19 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.jso
 class DistilBertTokenizer(BertTokenizer):
     model_input_names = ["input_ids", "attention_mask"]

+    def __init__(self, *args, do_lower_case: bool = True, **kwargs):
+        """
+        Construct a DistilBERT tokenizer (backed by HuggingFace's tokenizers library). Based on WordPiece.
+
+        This tokenizer inherits from [`BertTokenizer`] which contains most of the main methods. Users should refer to
+        this superclass for more information regarding those methods.
+
+        Args:
+            do_lower_case (`bool`, *optional*, defaults to `True`):
+                Whether or not to lowercase the input when tokenizing.
+        """
+        super().__init__(*args, do_lower_case=do_lower_case, **kwargs)
+

 # DistilBertTokenizerFast is an alias for DistilBertTokenizer (since BertTokenizer is already a fast tokenizer)
 DistilBertTokenizerFast = DistilBertTokenizer
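
A quick usage sketch for the new explicit `do_lower_case` argument; the tokenizer behaves as before, the argument is simply documented and forwarded to `BertTokenizer`:

from transformers import DistilBertTokenizer

tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
print(tok.tokenize("Hello World"))  # ['hello', 'world'] with the default do_lower_case=True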

transformers/models/doge/modeling_doge.py
@@ -42,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import AttentionInterface, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_doge import DogeConfig


@@ -127,7 +127,7 @@ class DogeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -297,7 +297,6 @@ class DogeAttention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]

transformers/models/doge/modular_doge.py
@@ -321,7 +321,6 @@ class DogeAttention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]

transformers/models/donut/modeling_donut_swin.py
@@ -837,6 +837,7 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, DonutSwinModelOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
@@ -923,6 +924,7 @@ class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, DonutSwinImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

transformers/models/dots1/modeling_dots1.py
@@ -29,7 +29,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dots1 import Dots1Config


@@ -119,7 +119,7 @@ class Dots1RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -201,6 +201,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Dots1Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -227,7 +228,6 @@ class Dots1Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
         self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
@@ -369,9 +369,11 @@ class Dots1MoE(nn.Module):

     def route_tokens_to_experts(self, router_logits):
         router_logits = router_logits.sigmoid()  # main diff with deepseekv3
-
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
         group_scores = (
-
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
         )
         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
         group_mask = torch.zeros_like(group_scores)
@@ -381,7 +383,7 @@ class Dots1MoE(nn.Module):
             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice =
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
         topk_weights = router_logits.gather(1, topk_indices)
         if self.norm_topk_prob:
@@ -467,6 +469,7 @@ class Dots1PreTrainedModel(PreTrainedModel):
         "hidden_states": Dots1DecoderLayer,
         "attentions": Dots1Attention,
     }
+    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]

    @torch.no_grad()
    def _init_weights(self, module):
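
The routing change above adds the `e_score_correction_bias` term for expert selection only; the gathered routing weights still come from the plain sigmoid scores, and the bias is kept in float32 via `_keep_in_fp32_modules_strict`. A self-contained sketch of the resulting group-limited top-k routing (simplified names and dummy shapes, not the library code):

import torch

def route_tokens_to_experts(router_logits, bias, n_group, topk_group, top_k):
    n_routed_experts = router_logits.shape[-1]
    scores = router_logits.sigmoid()                        # (num_tokens, n_routed_experts)
    scores_for_choice = scores + bias                       # bias only influences which experts win
    group_scores = (
        scores_for_choice.view(-1, n_group, n_routed_experts // n_group).topk(2, dim=-1)[0].sum(dim=-1)
    )
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    score_mask = (
        group_mask.unsqueeze(-1).expand(-1, n_group, n_routed_experts // n_group).reshape(-1, n_routed_experts)
    )
    masked_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
    topk_indices = torch.topk(masked_scores, k=top_k, dim=-1, sorted=False)[1]
    topk_weights = scores.gather(1, topk_indices)           # weights come from the unbiased sigmoid scores
    return topk_indices, topk_weights

# e.g. 8 tokens, 16 experts in 4 groups: keep the 2 best groups, then 4 experts per token
logits = torch.randn(8, 16)
idx, w = route_tokens_to_experts(logits, torch.zeros(16), n_group=4, topk_group=2, top_k=4)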

transformers/models/dots1/modular_dots1.py
@@ -61,9 +61,11 @@ class Dots1TopkRouter(DeepseekV3TopkRouter):
 class Dots1MoE(DeepseekV3MoE):
     def route_tokens_to_experts(self, router_logits):
         router_logits = router_logits.sigmoid()  # main diff with deepseekv3
-
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
         group_scores = (
-
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
         )
         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
         group_mask = torch.zeros_like(group_scores)
@@ -73,7 +75,7 @@ class Dots1MoE(DeepseekV3MoE):
             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice =
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
         topk_weights = router_logits.gather(1, topk_indices)
         if self.norm_topk_prob:

transformers/models/dpr/modeling_dpr.py
@@ -129,6 +129,7 @@ class DPREncoder(DPRPreTrainedModel):
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = False,
+        **kwargs,
     ) -> Union[BaseModelOutputWithPooling, tuple[Tensor, ...]]:
         outputs = self.bert_model(
             input_ids=input_ids,
@@ -181,6 +182,7 @@ class DPRSpanPredictor(DPRPreTrainedModel):
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = False,
+        **kwargs,
     ) -> Union[DPRReaderOutput, tuple[Tensor, ...]]:
         # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
         n_passages, sequence_length = input_ids.size() if input_ids is not None else inputs_embeds.size()[:2]
@@ -282,6 +284,7 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[DPRContextEncoderOutput, tuple[Tensor, ...]]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -387,6 +390,7 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[DPRQuestionEncoderOutput, tuple[Tensor, ...]]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -492,6 +496,7 @@ class DPRReader(DPRPretrainedReader):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[DPRReaderOutput, tuple[Tensor, ...]]:
         r"""
         input_ids (`tuple[torch.LongTensor]` of shapes `(n_passages, sequence_length)`):

transformers/models/dpr/tokenization_dpr.py
@@ -39,6 +39,10 @@ class DPRContextEncoderTokenizer(BertTokenizer):

     vocab_files_names = VOCAB_FILES_NAMES

+    def __init__(self, *args, do_lower_case=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.do_lower_case = do_lower_case
+

 class DPRQuestionEncoderTokenizer(BertTokenizer):
     r"""
@@ -52,6 +56,10 @@ class DPRQuestionEncoderTokenizer(BertTokenizer):

     vocab_files_names = VOCAB_FILES_NAMES

+    def __init__(self, *args, do_lower_case=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.do_lower_case = do_lower_case
+

 DPRSpanPrediction = collections.namedtuple(
     "DPRSpanPrediction", ["span_score", "relevance_score", "doc_id", "start_index", "end_index", "text"]
@@ -316,5 +324,9 @@ class DPRReaderTokenizer(CustomDPRReaderTokenizerMixin, BertTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]

+    def __init__(self, *args, do_lower_case=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.do_lower_case = do_lower_case
+

 __all__ = ["DPRContextEncoderTokenizer", "DPRQuestionEncoderTokenizer", "DPRReaderOutput", "DPRReaderTokenizer"]
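
A quick usage sketch for the new DPR tokenizer constructors; `do_lower_case` is now accepted and stored on the instance rather than rejected as an unexpected keyword argument:

from transformers import DPRQuestionEncoderTokenizer

tok = DPRQuestionEncoderTokenizer.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base", do_lower_case=True
)
print(tok.do_lower_case)  # True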

transformers/models/edgetam/modeling_edgetam.py
@@ -393,7 +393,7 @@ class EdgeTamVisionNeck(nn.Module):
         n = len(self.convs) - 1
         for i in range(n, -1, -1):
             lateral_features = hidden_states[i].permute(0, 3, 1, 2)
-            lateral_features = self.convs[n - i](lateral_features)
+            lateral_features = self.convs[n - i](lateral_features.to(self.convs[i].weight.dtype))
             if i not in self.fpn_top_down_levels or i == n:
                 prev_features = lateral_features
             else:

transformers/models/edgetam_video/modeling_edgetam_video.py
@@ -2117,6 +2117,7 @@ class EdgeTamVideoModel(EdgeTamVideoPreTrainedModel):
         frame_idx: Optional[int] = None,
         frame: Optional[torch.Tensor] = None,
         reverse: bool = False,
+        **kwargs,
     ) -> EdgeTamVideoSegmentationOutput:
         r"""
         inference_session (`EdgeTamVideoInferenceSession`):

transformers/models/edgetam_video/modular_edgetam_video.py
@@ -1256,6 +1256,7 @@ class EdgeTamVideoModel(Sam2VideoModel):
         frame_idx: Optional[int] = None,
         frame: Optional[torch.Tensor] = None,
         reverse: bool = False,
+        **kwargs,
     ) -> EdgeTamVideoSegmentationOutput:
         r"""
         inference_session (`EdgeTamVideoInferenceSession`):

transformers/models/efficientloftr/modeling_efficientloftr.py
@@ -33,7 +33,7 @@ from ...utils import (
     can_return_tuple,
     torch_int,
 )
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_efficientloftr import EfficientLoFTRConfig


@@ -147,7 +147,7 @@ class EfficientLoFTRRotaryEmbedding(nn.Module):
         embed_height = (feats_height - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
         embed_width = (feats_width - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             emb = compute_embeddings(self.inv_freq, embed_height, embed_width, self.config.hidden_size)
             sin = emb.sin()
             cos = emb.cos()

transformers/models/efficientnet/modeling_efficientnet.py
@@ -471,6 +471,7 @@ class EfficientNetModel(EfficientNetPreTrainedModel):
         pixel_values: Optional[torch.FloatTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -529,6 +530,7 @@ class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

transformers/models/emu3/modeling_emu3.py
@@ -33,7 +33,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -41,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig


@@ -118,6 +118,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Emu3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -143,7 +144,6 @@ class Emu3Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
@@ -1167,7 +1167,7 @@ class Emu3RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
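
As in the dots1 hunks, the `@use_kernelized_func(apply_rotary_pos_emb)` class decorator replaces the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment that is removed above. A minimal sketch of the shape such a decorator could take, assuming it only attaches the default function as a class attribute; the real helper comes from transformers.integrations (hub kernels) and can additionally swap in a downloaded kernel, so details may differ:

# Hypothetical sketch only; not the library implementation.
def use_kernelized_func(default_func):
    def decorator(cls):
        # attach the default implementation so forward() can call self.rotary_fn(...)
        cls.rotary_fn = staticmethod(default_func)
        return cls
    return decorator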

transformers/models/eomt/image_processing_eomt.py
@@ -815,7 +815,19 @@ class EomtImageProcessor(BaseImageProcessor):

         segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)

-
+        if patch_offsets:
+            output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
+        else:
+            output_logits = []
+
+            for idx in range(len(segmentation_logits)):
+                resized_logits = torch.nn.functional.interpolate(
+                    segmentation_logits[idx].unsqueeze(dim=0),
+                    size=target_sizes[idx],
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                output_logits.append(resized_logits[0])

         preds = [logit.argmax(dim=0) for logit in output_logits]
         return preds
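
When no `patch_offsets` are present, the post-processing above now falls back to resizing each image's logits to its original size before taking the per-pixel argmax. A minimal standalone sketch of that fallback path (dummy shapes, not the library code):

import torch
import torch.nn.functional as F

segmentation_logits = torch.randn(2, 19, 160, 160)   # (batch, num_classes, H, W)
target_sizes = [(480, 640), (512, 512)]              # original image sizes

preds = []
for logits, size in zip(segmentation_logits, target_sizes):
    resized = F.interpolate(logits.unsqueeze(0), size=size, mode="bilinear", align_corners=False)
    preds.append(resized[0].argmax(dim=0))           # (H, W) label map per image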

transformers/models/eomt/image_processing_eomt_fast.py
@@ -239,7 +239,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
         for shape, stacked_images in grouped_images.items():
             if do_resize:
                 stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
-
+            resized_images_grouped[shape] = stacked_images
         images = reorder_images(resized_images_grouped, grouped_images_index)

         # Group images by size for batched resizing, Needed in case do_resize is False.
@@ -385,7 +385,19 @@ class EomtImageProcessorFast(BaseImageProcessorFast):

         segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)

-
+        if patch_offsets:
+            output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
+        else:
+            output_logits = []
+
+            for idx in range(len(segmentation_logits)):
+                resized_logits = torch.nn.functional.interpolate(
+                    segmentation_logits[idx].unsqueeze(dim=0),
+                    size=target_sizes[idx],
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                output_logits.append(resized_logits[0])

         preds = [logit.argmax(dim=0) for logit in output_logits]
         return preds
@@ -27,7 +27,7 @@ from torch import nn
|
|
|
27
27
|
from ...activations import ACT2FN
|
|
28
28
|
from ...cache_utils import Cache, DynamicCache
|
|
29
29
|
from ...generation import GenerationMixin
|
|
30
|
-
from ...integrations import use_kernel_forward_from_hub
|
|
30
|
+
from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
|
|
31
31
|
from ...masking_utils import create_causal_mask
|
|
32
32
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
33
33
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
|
@@ -35,7 +35,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
|
35
35
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
36
36
|
from ...processing_utils import Unpack
|
|
37
37
|
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
|
38
|
-
from ...utils.generic import check_model_inputs
|
|
38
|
+
from ...utils.generic import check_model_inputs, maybe_autocast
|
|
39
39
|
from .configuration_ernie4_5 import Ernie4_5Config
|
|
40
40
|
|
|
41
41
|
|
|
@@ -95,7 +95,7 @@ class Ernie4_5RotaryEmbedding(nn.Module):
|
|
|
95
95
|
position_ids_expanded = position_ids[:, None, :].float()
|
|
96
96
|
|
|
97
97
|
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
|
98
|
-
with
|
|
98
|
+
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
|
99
99
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
|
100
100
|
emb = torch.cat((freqs, freqs), dim=-1)
|
|
101
101
|
cos = emb.cos() * self.attention_scaling
|
|
@@ -203,6 +203,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed.to(original_dtype), k_embed.to(original_dtype)


+@use_kernelized_func(apply_rotary_pos_emb)
 class Ernie4_5Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -221,7 +222,6 @@ class Ernie4_5Attention(nn.Module):
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
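The `@use_kernelized_func(apply_rotary_pos_emb)` decorator replaces the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment removed above. Its implementation is not part of this diff; the sketch below shows one plausible shape for such a decorator (attach the function, or an optimized kernel with the same signature, as a class attribute). Both `use_kernelized_func` and `_lookup_kernel` here are illustrative stand-ins, not the transformers API.

```python
# Hypothetical sketch of what a class decorator like use_kernelized_func could
# do, based only on how it is used in this diff: it makes the rotary function
# available on the decorated class instead of being assigned per instance.
def use_kernelized_func(func):
    def decorator(cls):
        kernelized = _lookup_kernel(func)  # hypothetical registry lookup
        cls.rotary_fn = staticmethod(kernelized if kernelized is not None else func)
        return cls
    return decorator

def _lookup_kernel(func):
    # Placeholder: a real registry would map `func` to an optimized kernel.
    return None
```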
@@ -18,6 +18,7 @@ from torch import nn

 from ...modeling_rope_utils import dynamic_rope_update
 from ...utils import auto_docstring, can_return_tuple
+from ...utils.generic import maybe_autocast
 from ..glm.modeling_glm import rotate_half
 from ..llama.modeling_llama import (
     LlamaAttention,
@@ -36,7 +37,7 @@ class Ernie4_5RotaryEmbedding(OlmoRotaryEmbedding):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -29,7 +29,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -37,7 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig


@@ -135,7 +135,7 @@ class Ernie4_5_MoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
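The context lines around these `maybe_autocast` changes are the standard RoPE table construction: inverse frequencies and positions are multiplied in float32, duplicated, and turned into cos/sin tables scaled by `attention_scaling`. A standalone sketch with arbitrary example sizes:

```python
import torch

# Standalone sketch of the rotary table computation that the hunks above wrap
# in a disabled-autocast region. Head dim, sequence length and the
# attention_scaling factor are arbitrary example values.
head_dim, seq_len, attention_scaling = 64, 16, 1.0
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))

inv_freq_expanded = inv_freq[None, :, None].float()                    # (1, head_dim/2, 1)
position_ids_expanded = torch.arange(seq_len).float()[None, None, :]   # (1, 1, seq_len)

freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)    # (1, seq_len, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                                # (1, seq_len, head_dim)
cos = emb.cos() * attention_scaling
sin = emb.sin() * attention_scaling
print(cos.shape, cos.dtype)  # float32 because the inputs are kept in float32
```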
@@ -226,6 +226,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Ernie4_5_MoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -244,7 +245,6 @@ class Ernie4_5_MoeAttention(nn.Module):
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
@@ -371,7 +371,7 @@ class Ernie4_5_MoeTopKRouter(nn.Module):
             else "cpu"
         )

-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             router_logits = F.linear(hidden_states.float(), self.weight)
             router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)
             router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1)
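The router hunks keep the top-k gating math in float32 under the (now maybe-)disabled autocast region. Below is a self-contained sketch of that math; the gate weight and the bias standing in for `moe_statics` are random placeholders, and the real `moe_statics` module may do more than add a bias.

```python
import torch
import torch.nn.functional as F

# Standalone sketch of the top-k routing shown above. Sizes are illustrative.
hidden_states = torch.randn(4, 1024)          # (tokens, hidden_size)
gate_weight = torch.randn(64, 1024)           # (num_experts, hidden_size)
expert_bias = torch.zeros(64)                 # stand-in for moe_statics
top_k = 8

router_logits = F.linear(hidden_states.float(), gate_weight)     # (tokens, num_experts)
router_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
router_top_value, router_indices = torch.topk(router_probs + expert_bias, top_k, dim=-1)
print(router_indices.shape)  # (tokens, top_k) selected experts per token
```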
@@ -26,7 +26,7 @@ from ...modeling_outputs import MoeModelOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half  # noqa: F401
 from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm
 from ..mixtral.modeling_mixtral import (
@@ -146,7 +146,7 @@ class Ernie4_5_MoeTopKRouter(nn.Module):
             else "cpu"
         )

-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             router_logits = F.linear(hidden_states.float(), self.weight)
             router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)
             router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1)
@@ -32,6 +32,7 @@ from ...utils import (
     auto_docstring,
     logging,
 )
+from ...utils.generic import maybe_autocast
 from .modeling_esm import EsmModel, EsmPreTrainedModel
 from .openfold_utils import (
     OFProtein,
@@ -267,7 +268,7 @@ class EsmFoldLayerNorm(nn.Module):
     def forward(self, x):
         d = x.dtype
         if d is torch.bfloat16 and not is_deepspeed_initialized():
-            with torch.autocast(device_type="cuda", enabled=False):
+            with maybe_autocast(device_type="cuda", enabled=False):
                 out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps)
         else:
             out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)
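The EsmFold layer-norm hunk keeps its bfloat16 workaround and only switches the context manager. As a standalone illustration of the pattern (run layer_norm with autocast disabled and the affine parameters cast to the activation dtype), with example shapes and `device_type="cpu"` so the sketch runs anywhere:

```python
import torch
from torch import nn

# When the activations are bfloat16, autocast is disabled and the weight/bias
# are cast to the activation dtype, so layer_norm runs entirely in bfloat16
# instead of being upcast. c_in and eps are illustrative values.
c_in, eps = (64,), 1e-5
weight = nn.Parameter(torch.ones(c_in))
bias = nn.Parameter(torch.zeros(c_in))
x = torch.randn(2, 10, 64, dtype=torch.bfloat16)

d = x.dtype
if d is torch.bfloat16:
    with torch.autocast(device_type="cpu", enabled=False):  # "cpu" here only so the sketch is portable
        out = nn.functional.layer_norm(x, c_in, weight.to(dtype=d), bias.to(dtype=d), eps)
else:
    out = nn.functional.layer_norm(x, c_in, weight, bias, eps)
print(out.dtype)  # torch.bfloat16
```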
@@ -282,7 +283,7 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
     """
     d = t.dtype
     if d is torch.bfloat16 and not is_deepspeed_initialized():
-        with torch.autocast(device_type="cuda", enabled=False):
+        with maybe_autocast(device_type="cuda", enabled=False):
            s = torch.nn.functional.softmax(t, dim=dim)
     else:
         s = torch.nn.functional.softmax(t, dim=dim)
@@ -868,7 +869,7 @@ class EsmFoldTriangleMultiplicativeUpdate(nn.Module):

         device_type = a.device.type if a.device.type != "mps" else "cpu"
         if is_fp16_enabled(device_type):
-            with torch.autocast(device_type=device_type, enabled=False):
+            with maybe_autocast(device_type=device_type, enabled=False):
                 x = self._combine_projections(a.float(), b.float())
         else:
             x = self._combine_projections(a, b)
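The triangle-update and invariant-point-attention hunks follow the same guard: if float16 autocast is active for the device, disable it and compute the numerically sensitive matmul in float32. In the sketch below, `is_fp16_enabled` is a simplified stand-in for the transformers helper, which inspects autocast state for the given device.

```python
import torch

def is_fp16_enabled(device_type: str) -> bool:
    # Simplified stand-in: the real helper checks whether float16 autocast is
    # currently active for `device_type`.
    return device_type == "cuda" and torch.is_autocast_enabled()

a = torch.randn(4, 8)
b = torch.randn(8, 4)
device_type = a.device.type if a.device.type != "mps" else "cpu"
if is_fp16_enabled(device_type):
    # Leave autocast so the matmul really happens in float32.
    with torch.autocast(device_type=device_type, enabled=False):
        x = a.float() @ b.float()
else:
    x = a @ b
print(x.shape)
```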
@@ -1491,7 +1492,7 @@ class EsmFoldInvariantPointAttention(nn.Module):
         # [*, H, N_res, N_res]
         device_type = q.device.type if q.device.type != "mps" else "cpu"
         if is_fp16_enabled(device_type):
-            with torch.autocast(device_type=device_type, enabled=False):
+            with maybe_autocast(device_type=device_type, enabled=False):
                 a = torch.matmul(
                     permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                     permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]