transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
--- a/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
+++ b/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
@@ -39,7 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_modernbert_decoder import ModernBertDecoderConfig


@@ -168,7 +168,7 @@ class ModernBertDecoderRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * attention_scaling
@@ -342,7 +342,7 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.attn_norm(hidden_states)
@@ -477,7 +477,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
         if (input_ids is None) == (inputs_embeds is None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -489,7 +489,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         batch_size, seq_length = inputs_embeds.shape[:2]

         # Handle past_key_values and cache setup
-        if use_cache and past_key_values is None
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
@@ -527,13 +527,12 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for decoder_layer in self.layers:
             hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_embeddings=position_embeddings[decoder_layer.attention_type],
                 past_key_values=past_key_values,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 position_ids=position_ids,
                 **kwargs,
@@ -583,7 +582,7 @@ class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationMixin):
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -686,7 +685,7 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
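Two of the hunks above are substantive: rc1 routes the rotary-embedding autocast guard through a new `maybe_autocast` helper imported from `transformers.utils.generic` instead of calling `torch.autocast` directly. A minimal sketch of what such a wrapper can look like, assuming it simply degrades to a no-op where autocast is unavailable (the helper below is illustrative, not the actual transformers implementation):

```python
# Hypothetical sketch of a maybe_autocast helper; the real implementation in
# transformers.utils.generic may differ in signature and fallback behavior.
import contextlib
import torch

def maybe_autocast(device_type: str = "cpu", enabled: bool = True, **kwargs):
    """torch.autocast when the backend supports it, otherwise a no-op context."""
    try:
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:  # backend without autocast support
        return contextlib.nullcontext()

# Usage mirroring the rotary hunk: keep the cos/sin tables in float32 even
# when the surrounding code runs under mixed precision.
inv_freq_expanded = torch.rand(1, 8, 1)
position_ids_expanded = torch.arange(16, dtype=torch.float32)[None, None, :]
with maybe_autocast(device_type="cpu", enabled=False):  # Force float32
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos(), emb.sin()
```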
--- a/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
+++ b/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
@@ -394,7 +394,7 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.attn_norm(hidden_states)
@@ -525,7 +525,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
         if (input_ids is None) == (inputs_embeds is None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -537,7 +537,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         batch_size, seq_length = inputs_embeds.shape[:2]

         # Handle past_key_values and cache setup
-        if use_cache and past_key_values is None
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
@@ -575,13 +575,12 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for decoder_layer in self.layers:
             hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_embeddings=position_embeddings[decoder_layer.attention_type],
                 past_key_values=past_key_values,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 position_ids=position_ids,
                 **kwargs,
@@ -631,7 +630,7 @@ class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationMixin):
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -734,7 +733,7 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
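The recurring `**kwargs,` → `**kwargs: Unpack[TransformersKwargs],` change in both files is the PEP 692 typed-kwargs pattern: a `TypedDict` names the accepted keyword arguments and `Unpack` applies it to `**kwargs`, so static checkers can validate call sites while runtime behavior is unchanged. A self-contained sketch with an illustrative TypedDict (not the real `TransformersKwargs` definition):

```python
from typing import TypedDict
from typing_extensions import Unpack  # typing.Unpack on Python >= 3.12

class AttnKwargs(TypedDict, total=False):
    # Illustrative keys only; the actual TransformersKwargs differs.
    output_attentions: bool
    output_hidden_states: bool

def forward(hidden_states: list, **kwargs: Unpack[AttnKwargs]) -> list:
    # A type checker now rejects e.g. forward(x, output_attentionz=True);
    # at runtime, kwargs is still a plain dict.
    if kwargs.get("output_attentions"):
        pass
    return hidden_states
```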
--- a/transformers/models/moonshine/modeling_moonshine.py
+++ b/transformers/models/moonshine/modeling_moonshine.py
@@ -30,6 +30,7 @@ from transformers.utils.generic import OutputRecorder, check_model_inputs
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -45,6 +46,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import maybe_autocast
 from .configuration_moonshine import MoonshineConfig


@@ -137,7 +139,7 @@ class MoonshineRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -233,6 +235,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class MoonshineAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -264,7 +267,6 @@ class MoonshineAttention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb

         # Pad head dimension to the next specified multiple.
         if self.config.pad_head_dim_to_multiple_of is not None:
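Moonshine's attention now gets its rotary function from the `@use_kernelized_func(apply_rotary_pos_emb)` class decorator rather than assigning `self.rotary_fn = apply_rotary_pos_emb` in `__init__`. Judging only from this diff, the decorator installs the given function on the class, presumably substituting a faster hub kernel when one is available. A hypothetical sketch of that pattern; the real `use_kernelized_func` in `transformers.integrations` may differ:

```python
def use_kernelized_func(fallback_fn, attr_name="rotary_fn"):
    """Class decorator installing fallback_fn as a static attribute on the class.
    Hypothetically, a faster kernel could be looked up here instead."""
    def decorator(cls):
        setattr(cls, attr_name, staticmethod(fallback_fn))
        return cls
    return decorator

def apply_rotary_pos_emb(q, k, cos, sin):
    return q * cos + k * sin  # stand-in math, not the real rotary formula

@use_kernelized_func(apply_rotary_pos_emb)
class Attention:
    def rotate(self, q, k, cos, sin):
        return self.rotary_fn(q, k, cos, sin)  # resolves to the installed function
```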
--- a/transformers/models/moshi/modeling_moshi.py
+++ b/transformers/models/moshi/modeling_moshi.py
@@ -34,6 +34,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast,
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+from ...utils.generic import maybe_autocast
 from ..auto.modeling_auto import AutoModel
 from .configuration_moshi import MoshiConfig, MoshiDepthConfig

@@ -327,7 +328,7 @@ class MoshiRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -882,6 +883,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
         position_ids: Optional[torch.LongTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         """
         Args:
@@ -957,7 +959,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
             )
             use_cache = False

-        if use_cache and past_key_values is None
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
@@ -1228,6 +1230,7 @@ class MoshiModel(MoshiPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
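Both decoders build their default cache as `DynamicCache(config=self.config)` when `use_cache` is set and no cache was passed in. The same object can also be supplied explicitly by the caller; a sketch of that usage (the model id is a placeholder):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "your-org/your-model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
cache = DynamicCache(config=model.config)  # mirrors the in-model default above
model(**inputs, past_key_values=cache, use_cache=True)
print(cache.get_seq_length())  # number of cached tokens after the forward pass
```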
--- a/transformers/models/mpnet/modeling_mpnet.py
+++ b/transformers/models/mpnet/modeling_mpnet.py
@@ -488,6 +488,7 @@ class MPNetForMaskedLM(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -577,6 +578,7 @@ class MPNetForSequenceClassification(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -656,6 +658,7 @@ class MPNetForMultipleChoice(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -748,6 +751,7 @@ class MPNetForTokenClassification(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -831,6 +835,7 @@ class MPNetForQuestionAnswering(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

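The five MPNet heads above (and the MPT, MRA, MT5, Musicgen, and MVP hunks below) all gain a bare trailing `**kwargs,` in `forward`. A plausible reading is call-site forward compatibility: shared machinery can pass newer arguments to every head without each signature having to name them. A minimal illustration of the difference:

```python
def forward_without(input_ids, return_dict=None):
    return input_ids

def forward_with(input_ids, return_dict=None, **kwargs):
    return input_ids  # unknown keyword arguments are accepted and ignored

args = {"input_ids": [1, 2, 3], "return_dict": True, "new_flag": False}
forward_with(**args)       # fine
# forward_without(**args)  # TypeError: unexpected keyword argument 'new_flag'
```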
--- a/transformers/models/mpnet/tokenization_mpnet.py
+++ b/transformers/models/mpnet/tokenization_mpnet.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 """Tokenization classes for MPNet."""

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import WordPiece
@@ -38,7 +38,7 @@ class MPNetTokenizer(TokenizersBackend):
     refer to this superclass for more information regarding those methods.

     Args:
-        vocab (`dict`, *optional*):
+        vocab (`str` or `dict[str, int]`, *optional*):
             Dictionary mapping tokens to their IDs. If not provided, an empty vocab is initialized.
         do_lower_case (`bool`, *optional*, defaults to `True`):
             Whether or not to lowercase the input when tokenizing.
@@ -87,10 +87,11 @@ class MPNetTokenizer(TokenizersBackend):

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
+    model = WordPiece

     def __init__(
         self,
-        vocab: Optional[dict] = None,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
         do_lower_case=True,
         bos_token="<s>",
         eos_token="</s>",
@@ -104,12 +105,7 @@ class MPNetTokenizer(TokenizersBackend):
         **kwargs,
     ):
         # Initialize vocab
-        if vocab is not None:
-            self._vocab = (
-                {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
-            )
-        else:
-            self._vocab = {}
+        self._vocab = vocab if vocab is not None else {}

         # Initialize the tokenizer with WordPiece model
         self._tokenizer = Tokenizer(WordPiece(self._vocab, unk_token=str(unk_token)))
@@ -142,11 +138,7 @@ class MPNetTokenizer(TokenizersBackend):
         # Mask token behave like a normal word, i.e. include the space before it
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

-        # Store for later use
-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             do_lower_case=do_lower_case,
             bos_token=bos_token,
             eos_token=eos_token,
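The simplified `__init__` above builds its backend as `Tokenizer(WordPiece(self._vocab, unk_token=...))` and wraps the mask token as `AddedToken(..., lstrip=True, rstrip=False)` so it absorbs the preceding space. The same construction with the `tokenizers` library and a toy vocabulary:

```python
from tokenizers import AddedToken, Tokenizer, pre_tokenizers
from tokenizers.models import WordPiece

vocab = {"[UNK]": 0, "<s>": 1, "</s>": 2, "<mask>": 3, "hello": 4, "world": 5}
tok = Tokenizer(WordPiece(vocab, unk_token="[UNK]"))
tok.pre_tokenizer = pre_tokenizers.Whitespace()

# lstrip=True makes "<mask>" swallow the space before it, as in the diff above.
tok.add_special_tokens([AddedToken("<mask>", lstrip=True, rstrip=False)])

print(tok.encode("hello world").tokens)  # ['hello', 'world']
```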
--- a/transformers/models/mpt/modeling_mpt.py
+++ b/transformers/models/mpt/modeling_mpt.py
@@ -498,6 +498,7 @@ class MptForSequenceClassification(MptPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -700,6 +701,7 @@ class MptForQuestionAnswering(MptPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
--- a/transformers/models/mra/modeling_mra.py
+++ b/transformers/models/mra/modeling_mra.py
@@ -826,6 +826,7 @@ class MraModel(MraPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithCrossAttentions]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -919,6 +920,7 @@ class MraForMaskedLM(MraPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1007,6 +1009,7 @@ class MraForSequenceClassification(MraPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1086,6 +1089,7 @@ class MraForMultipleChoice(MraPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MultipleChoiceModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1189,6 +1193,7 @@ class MraForTokenClassification(MraPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1263,6 +1268,7 @@ class MraForQuestionAnswering(MraPreTrainedModel):
         end_positions: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

--- a/transformers/models/mt5/modeling_mt5.py
+++ b/transformers/models/mt5/modeling_mt5.py
@@ -671,6 +671,7 @@ class MT5Stack(MT5PreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         cache_position=None,
+        **kwargs,
     ):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -898,6 +899,7 @@ class MT5Model(MT5PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1081,6 +1083,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1268,6 +1271,7 @@ class MT5EncoderModel(MT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1340,6 +1344,7 @@ class MT5ForSequenceClassification(MT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1480,6 +1485,7 @@ class MT5ForTokenClassification(MT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1587,6 +1593,7 @@ class MT5ForQuestionAnswering(MT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
--- a/transformers/models/musicgen/modeling_musicgen.py
+++ b/transformers/models/musicgen/modeling_musicgen.py
@@ -482,6 +482,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
@@ -716,6 +717,7 @@ class MusicgenModel(MusicgenPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
transformers/models/musicgen_melody/modeling_musicgen_melody.py

@@ -455,6 +455,7 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
@@ -670,6 +671,7 @@ class MusicgenMelodyModel(MusicgenMelodyPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
@@ -785,6 +787,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, MusicgenMelodyOutputWithPast]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
transformers/models/mvp/modeling_mvp.py

@@ -534,6 +534,7 @@ class MvpEncoder(MvpPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         Args:
@@ -698,6 +699,7 @@ class MvpDecoder(MvpPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         Args:
@@ -917,6 +919,7 @@ class MvpModel(MvpPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqModelOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1065,6 +1068,7 @@ class MvpForConditionalGeneration(MvpPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqLMOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1213,6 +1217,7 @@ class MvpForSequenceClassification(MvpPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1372,6 +1377,7 @@ class MvpForQuestionAnswering(MvpPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1548,6 +1554,7 @@ class MvpForCausalLM(MvpPreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
transformers/models/nanochat/modeling_nanochat.py

@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_func_from_hub
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_nanochat import NanoChatConfig


@@ -113,7 +113,7 @@ class NanoChatRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
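`maybe_autocast` comes from `transformers.utils.generic` (imported in the hunk above) and replaces the bare `torch.autocast(...)` context; its body is not part of this diff. A plausible sketch, assuming it is a thin wrapper that degrades to a no-op context where autocast is unsupported:

```python
# Hypothetical re-implementation for illustration; the real helper is defined
# in transformers.utils.generic and may differ in detail.
from contextlib import nullcontext

import torch


def maybe_autocast(device_type: str = "cpu", enabled: bool = False, **kwargs):
    """Return torch.autocast when the device supports it, else a no-op context."""
    try:
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:  # e.g. an unsupported device_type string
        return nullcontext()
```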
@@ -195,6 +195,7 @@ def rotate_half(x):
     return torch.cat((x2, -x1), dim=-1)


+@use_kernelized_func(apply_rotary_pos_emb)
 class NanoChatAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -220,7 +221,6 @@ class NanoChatAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

         self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
         self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
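Together with the import change above, these two NanoChat hunks trade a per-instance function pointer (`self.rotary_fn = apply_rotary_pos_emb`) for a class decorator, so an optimized kernel can be swapped in at a single decoration point. The real decorator lives in `transformers.integrations` and is not shown in this diff; a minimal sketch of the pattern, under that assumption:

```python
# Hypothetical sketch of the decorator pattern; the real `use_kernelized_func`
# in transformers.integrations may first try to fetch an optimized kernel from
# the Hugging Face Hub and fall back to the eager `func`.
def use_kernelized_func(func):
    def decorate(cls):
        # Expose the (possibly kernel-accelerated) function at class level,
        # replacing the removed per-instance `self.rotary_fn = func`.
        cls.rotary_fn = staticmethod(func)
        return cls

    return decorate
```

Call sites inside `forward` can then keep invoking `self.rotary_fn(...)` unchanged.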
transformers/models/nemotron/modeling_nemotron.py

@@ -45,6 +45,7 @@ from ...modeling_rope_utils import (
 )
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils.generic import maybe_autocast
 from .configuration_nemotron import NemotronConfig


@@ -87,7 +88,7 @@ class NemotronLayerNorm1P(nn.LayerNorm):
         args = _cast_if_autocast_enabled(
             device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps
         )
-        with torch.autocast(device_type=device_type, enabled=False):
+        with maybe_autocast(device_type=input.device.type, enabled=False):
             return F.layer_norm(*args)


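The `self.weight + 1` visible in the context lines is the "1p" reparameterization: the learnable gain is stored zero-centered and shifted by one at call time, so a zero-initialized (or weight-decayed) parameter corresponds to an effective gain of exactly 1. A self-contained sketch of the computation, with illustrative shapes:

```python
import torch
import torch.nn.functional as F

# Zero-centered gain: a stored weight of 0 yields an effective gain of exactly
# 1, so weight decay pulls the gain toward 1 rather than toward 0.
hidden = torch.randn(2, 8)  # illustrative (batch, hidden) shape
weight = torch.zeros(8)     # the learnable "1p" parameter
bias = torch.zeros(8)
out = F.layer_norm(hidden, (8,), weight + 1, bias, 1e-5)
```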
@@ -151,7 +152,7 @@ class NemotronRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -657,6 +658,7 @@ class NemotronModel(NemotronPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (