transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
```diff
@@ -14,12 +14,12 @@
 # limitations under the License.
 """Tokenization classes for OpenAI GPT."""
 
-from typing import Optional
+from typing import Optional, Union
 
 from tokenizers import Tokenizer, decoders, pre_tokenizers
 from tokenizers.models import BPE
 
-from ...tokenization_utils_tokenizers import TokenizersBackend
+from ...tokenization_utils_tokenizers import AddedToken, TokenizersBackend
 from ...utils import logging
 
 
@@ -84,45 +84,31 @@ class GPT2Tokenizer(TokenizersBackend):
         add_bos_token (`bool`, *optional*, defaults to `False`):
             Whether or not to add an initial beginning of sentence token to the input. This allows to treat the leading
             word just as any other word.
-        vocab (`dict`, *optional*):
-            Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file
-        merges (`list`, *optional*):
-            Custom merges list. If not provided, merges are loaded from merges_file
+        vocab (`str` or `dict[str, int]`, *optional*):
+            Custom vocabulary dictionary. If not provided, vocabulary is loaded from `vocab_file`.
+        merges (`str` or `list[str]`, *optional*):
+            Custom merges list. If not provided, merges are loaded from `merges_file`.
     """
 
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = BPE
 
     def __init__(
         self,
-
-
-
-
-
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
+        errors: str = "replace",
+        unk_token: Union[AddedToken, str] = "<|endoftext|>",
+        bos_token: Union[AddedToken, str] = "<|endoftext|>",
+        eos_token: Union[AddedToken, str] = "<|endoftext|>",
+        pad_token: Optional[Union[AddedToken, str]] = None,
         add_prefix_space=False,
-        add_bos_token=False,
-        vocab: Optional[dict] = None,
-        merges: Optional[list] = None,
         **kwargs,
     ):
-        # self.add_bos_token = add_bos_token
-
         self.add_prefix_space = add_prefix_space
-
-
-            self._vocab = (
-                {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
-            )
-        else:
-            self._vocab = {}
-
-        if merges is not None:
-            self._merges = [tuple(merge) if isinstance(merge, list) else merge for merge in merges]
-        else:
-            self._merges = []
-
+        self._vocab = vocab if vocab is not None else {}
+        self._merges = merges or []
         self._tokenizer = Tokenizer(
             BPE(
                 vocab=self._vocab,
@@ -133,31 +119,17 @@ class GPT2Tokenizer(TokenizersBackend):
                 fuse_unk=False,
             )
         )
-
         self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
         self._tokenizer.decoder = decoders.ByteLevel()
-
-        tokenizer_object = self._tokenizer
-
-        # Set these before calling super().__init__() so the base class _post_init() can use them
-        self._add_bos_token = add_bos_token
-        self._add_eos_token = False
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             errors=errors,
             unk_token=unk_token,
            bos_token=bos_token,
             eos_token=eos_token,
             pad_token=pad_token,
             add_prefix_space=add_prefix_space,
-            add_bos_token=add_bos_token,
             **kwargs,
         )
 
-        # Call _post_init for tokenizers created directly (not from_pretrained)
-        # For from_pretrained, this will be called again after loading the tokenizer from file
-        self._post_init()
-
 
 __all__ = ["GPT2Tokenizer"]
```
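The rewritten `GPT2Tokenizer.__init__` above reduces to handing `vocab` and `merges` straight to a `tokenizers` BPE model and wiring up byte-level pre-tokenization and decoding; the `tokenizer_object`/`_post_init` plumbing is gone. A minimal standalone sketch of that construction path using only the `tokenizers` library, with a made-up toy vocabulary and merge table (not the real GPT-2 files):

```python
from tokenizers import Tokenizer, decoders, pre_tokenizers
from tokenizers.models import BPE

# Toy data, invented for illustration only.
vocab = {"h": 0, "e": 1, "l": 2, "o": 3, "he": 4, "ll": 5, "llo": 6, "hello": 7}
merges = [("h", "e"), ("l", "l"), ("ll", "o"), ("he", "llo")]

# Same shape as the new __init__: BPE model + byte-level pre-tokenizer/decoder.
tok = Tokenizer(BPE(vocab=vocab, merges=merges, fuse_unk=False))
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tok.decoder = decoders.ByteLevel()

print(tok.encode("hello").tokens)  # ['hello'] under this toy merge table
```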
```diff
@@ -826,6 +826,7 @@ class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
@@ -419,6 +419,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -773,6 +774,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -894,6 +896,7 @@ class GPTNeoForTokenClassification(GPTNeoPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -974,6 +977,7 @@ class GPTNeoForQuestionAnswering(GPTNeoPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
```
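The hunks above (and the analogous ones that follow for GPT-NeoX, GPT-J, and others) only widen `forward` signatures with a trailing `**kwargs`, so extra keyword arguments forwarded by shared generation or serving code are absorbed instead of raising `TypeError`. A schematic illustration of the difference; the class names and the extra keyword are invented for this sketch:

```python
from typing import Optional


class HeadBefore:
    def forward(self, input_ids, output_attentions: Optional[bool] = None):
        return input_ids


class HeadAfter:
    def forward(self, input_ids, output_attentions: Optional[bool] = None, **kwargs):
        # Unrecognized keyword arguments are simply absorbed here.
        return input_ids


shared_kwargs = {"output_attentions": False, "logits_to_keep": 1}  # second key is hypothetical

HeadAfter().forward([1, 2, 3], **shared_kwargs)  # accepted
try:
    HeadBefore().forward([1, 2, 3], **shared_kwargs)
except TypeError as err:
    print(err)  # unexpected keyword argument 'logits_to_keep'
```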
```diff
@@ -28,7 +28,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_gpt_neox import GPTNeoXConfig
 
 
@@ -107,7 +107,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
```
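`maybe_autocast` is imported from `transformers.utils.generic` in the hunk above and wraps the rotary-embedding math so the `cos`/`sin` tables stay in float32 even under mixed-precision autocast. The helper below is a rough, hypothetical stand-in for what such a wrapper can look like (the real 5.0.0rc1 implementation may differ); the usage mirrors the RoPE `forward` shown in the diff:

```python
import contextlib

import torch


def maybe_autocast_sketch(device_type: str, enabled: bool = True):
    # Hypothetical stand-in: use torch.autocast where the device type supports
    # it, otherwise fall back to a no-op context manager.
    if device_type in ("cuda", "cpu"):
        return torch.autocast(device_type=device_type, enabled=enabled)
    return contextlib.nullcontext()


inv_freq_expanded = torch.randn(1, 8, 1)
position_ids_expanded = torch.arange(16).float()[None, None, :]
device_type = inv_freq_expanded.device.type

with maybe_autocast_sketch(device_type=device_type, enabled=False):  # force float32
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos(), emb.sin()
print(cos.dtype)  # torch.float32
```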
```diff
@@ -645,6 +645,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -724,6 +725,7 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> TokenClassifierOutput:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -783,6 +785,7 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel):
         end_positions: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> QuestionAnsweringModelOutput:
         outputs: BaseModelOutputWithPast = self.gpt_neox(
             input_ids,
@@ -518,6 +518,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -597,6 +598,7 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> TokenClassifierOutput:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -656,6 +658,7 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel):
         end_positions: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> QuestionAnsweringModelOutput:
         outputs: BaseModelOutputWithPast = self.gpt_neox(
             input_ids,
```
```diff
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Tokenization classes for GPTNeoX."""
 
-from typing import Optional
+from typing import Optional, Union
 
 from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
 from tokenizers.models import BPE
@@ -87,51 +87,34 @@ class GPTNeoXTokenizer(TokenizersBackend):
             Whether or not to add an `eos_token` at the end of sequences.
         trim_offsets (`bool`, *optional*, defaults to `True`):
             Whether or not the post-processing step should trim offsets to avoid including whitespaces.
-        vocab (`dict`, *optional*):
-            Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file
-        merges (`list`, *optional*):
-            Custom merges list. If not provided, merges are loaded from merges_file
+        vocab (`str` or `dict[str, int]`, *optional*):
+            Custom vocabulary dictionary. If not provided, vocabulary is loaded from `vocab_file`.
+        merges (`str` or `list[str]`, *optional*):
+            Custom merges list. If not provided, merges are loaded from `merges_file`.
     """
 
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = BPE
 
     def __init__(
         self,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
         errors: str = "replace",
         unk_token: str = "<|endoftext|>",
         bos_token: str = "<|endoftext|>",
         eos_token: str = "<|endoftext|>",
         pad_token: str = "<|padding|>",
-        add_bos_token: bool = False,
-        add_eos_token: bool = False,
         add_prefix_space: bool = False,
         trim_offsets: bool = True,
-        vocab: Optional[dict] = None,
-        merges: Optional[list] = None,
         **kwargs,
     ):
-        self._add_bos_token = add_bos_token
-        self._add_eos_token = add_eos_token
         self.add_prefix_space = add_prefix_space
         self.trim_offsets = trim_offsets
 
-        if vocab is not None:
-
-                {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
-            )
-        else:
-            self._vocab = {
-                str(unk_token): 0,
-                str(pad_token): 1,
-            }
-
-        if merges is not None:
-            self._merges = merges
-        else:
-            self._merges = []
-
+        self._vocab = vocab if vocab is not None else {str(unk_token): 0, str(pad_token): 1}
+        self._merges = merges or []
         self._tokenizer = Tokenizer(
             BPE(
                 vocab=self._vocab,
@@ -149,38 +132,16 @@ class GPTNeoXTokenizer(TokenizersBackend):
         )
         self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=False, trim_offsets=True)
 
-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             errors=errors,
             unk_token=unk_token,
             bos_token=bos_token,
             eos_token=eos_token,
             pad_token=pad_token,
-            add_bos_token=add_bos_token,
-            add_eos_token=add_eos_token,
             add_prefix_space=add_prefix_space,
             trim_offsets=trim_offsets,
             **kwargs,
         )
 
-        self.update_post_processor()
-
-    def _post_init(self):
-        """Post-initialization to ensure tokenizer settings are applied correctly."""
-        # Re-apply settings to ensure they're correct after loading from pretrained
-        self._tokenizer.normalizer = normalizers.NFC()
-        self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(
-            add_prefix_space=self.add_prefix_space, trim_offsets=self.trim_offsets
-        )
-        self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=False, trim_offsets=True)
-
-        # Call parent to handle AddedToken properties
-        super()._post_init()
-
-        # Update post processor with current bos/eos settings
-        self.update_post_processor()
-
 
 __all__ = ["GPTNeoXTokenizer"]
```
```diff
@@ -30,6 +30,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+from ...utils.generic import maybe_autocast
 from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig
 
 
@@ -116,7 +117,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -431,6 +432,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         r"""
         Example:
```
```diff
@@ -28,6 +28,7 @@ from torch.nn import functional as F
 from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
 from ...integrations.hub_kernels import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import (
@@ -40,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_gpt_oss import GptOssConfig
 
 
@@ -235,7 +236,7 @@ class GptOssRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = freqs
             cos = emb.cos() * self.attention_scaling
@@ -301,12 +302,13 @@ def eager_attention_forward(
     combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
     probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
     scores = probs[..., :-1]  # we drop the sink here
-    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
+    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training).to(value_states.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class GptOssAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
```
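Besides swapping in `maybe_autocast` and the `@use_kernelized_func` decorator, the GPT-OSS hunk adds `.to(value_states.dtype)` on the dropped-out attention weights. The likely reason: the sink-augmented softmax can produce higher-precision probabilities, and `torch.matmul` rejects mixed float dtypes in eager mode. A small, self-contained reproduction of that failure mode and the fix (shapes are arbitrary; this is not the library code):

```python
import torch

attn_weights = torch.softmax(torch.randn(1, 2, 4, 5), dim=-1)          # float32 probabilities
value_states = torch.randn(1, 2, 5, 8, dtype=torch.bfloat16)           # bf16 values

try:
    torch.matmul(attn_weights, value_states)
except RuntimeError as err:
    print(err)  # dtype mismatch: float32 vs bfloat16

out = torch.matmul(attn_weights.to(value_states.dtype), value_states)  # the fix from the diff
print(out.dtype)  # torch.bfloat16
```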
```diff
@@ -332,7 +334,6 @@ class GptOssAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
         self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))
 
@@ -343,7 +344,6 @@ class GptOssAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
@@ -373,7 +373,6 @@ class GptOssAttention(nn.Module):
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
             sliding_window=self.sliding_window,
-            position_ids=position_ids,
             s_aux=self.sinks,  # diff with Llama
             **kwargs,
         )
```
```diff
@@ -34,7 +34,7 @@ from ...utils import (
     auto_docstring,
     logging,
 )
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from ..llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaPreTrainedModel,
@@ -185,7 +185,7 @@ class GptOssRotaryEmbedding(Qwen2RotaryEmbedding):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = freqs
             cos = emb.cos() * self.attention_scaling
@@ -239,7 +239,7 @@ def eager_attention_forward(
     combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
     probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
     scores = probs[..., :-1]  # we drop the sink here
-    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
+    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training).to(value_states.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attn_weights
@@ -269,7 +269,6 @@ class GptOssAttention(Qwen2Attention):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
@@ -299,7 +298,6 @@ class GptOssAttention(Qwen2Attention):
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
             sliding_window=self.sliding_window,
-            position_ids=position_ids,
             s_aux=self.sinks,  # diff with Llama
             **kwargs,
         )
```
```diff
@@ -482,6 +482,7 @@ class GPTJModel(GPTJPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
@@ -819,6 +820,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
@@ -930,6 +932,7 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
```
```diff
@@ -28,7 +28,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -36,7 +36,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_granite import GraniteConfig
 
 
@@ -116,6 +116,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class GraniteAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
```
@@ -141,7 +142,6 @@ class GraniteAttention(nn.Module):
|
|
|
141
142
|
self.o_proj = nn.Linear(
|
|
142
143
|
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
|
143
144
|
)
|
|
144
|
-
self.rotary_fn = apply_rotary_pos_emb
|
|
145
145
|
|
|
146
146
|
def forward(
|
|
147
147
|
self,
|
|
@@ -376,7 +376,7 @@ class GraniteRotaryEmbedding(nn.Module):
|
|
|
376
376
|
position_ids_expanded = position_ids[:, None, :].float()
|
|
377
377
|
|
|
378
378
|
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
|
379
|
-
with
|
|
379
|
+
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
|
380
380
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
|
381
381
|
emb = torch.cat((freqs, freqs), dim=-1)
|
|
382
382
|
cos = emb.cos() * self.attention_scaling
|
|
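In the Granite attention classes, the `self.rotary_fn = apply_rotary_pos_emb` assignment is replaced by a class decorator, `@use_kernelized_func(apply_rotary_pos_emb)`, imported from `...integrations`. The diff does not show that decorator's implementation; the sketch below is only a guess at its general shape (attach the function to the class so a hub-provided kernel could later be substituted), with a simplified rotary helper:

```python
from typing import Callable

import torch
from torch import nn


def use_kernelized_func(func: Callable) -> Callable[[type], type]:
    """Hypothetical stand-in for the real decorator: expose `func` on the class as `rotary_fn`."""

    def decorator(cls: type) -> type:
        # A real integration could swap `func` for an optimized kernel from the hub here.
        cls.rotary_fn = staticmethod(func)
        return cls

    return decorator


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # Standard rotary formula, without the unsqueeze handling the library version carries.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


@use_kernelized_func(apply_rotary_pos_emb)
class ToyAttention(nn.Module):
    def forward(self, q, k, cos, sin):
        # Call sites keep using self.rotary_fn, as they did before the refactor.
        return self.rotary_fn(q, k, cos, sin)
```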
@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast

@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring
-from ...utils.generic import can_return_tuple, check_model_inputs
+from ...utils.generic import can_return_tuple, check_model_inputs, maybe_autocast
 from .configuration_granitemoe import GraniteMoeConfig

@@ -119,7 +119,7 @@ class GraniteMoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -338,6 +338,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class GraniteMoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -363,7 +364,6 @@ class GraniteMoeAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,

@@ -714,8 +714,6 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
             # Flatten the tokens
             loss = self.loss_function(
                 logits,

@@ -295,8 +295,6 @@ class GraniteMoeForCausalLM(MixtralForCausalLM):

         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
             # Flatten the tokens
             loss = self.loss_function(
                 logits,
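As in the Granite hunks, the rotary embeddings here switch from a bare `torch.autocast` context to `maybe_autocast`, newly imported from `...utils.generic`. Its implementation is not part of this diff; a plausible reading is a thin wrapper that degrades to a no-op context manager where autocast does not apply, roughly like this hypothetical sketch:

```python
import contextlib

import torch


def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
    """Hypothetical sketch only: autocast when the device type supports it, otherwise a no-op."""
    if device_type in ("cuda", "cpu"):
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    return contextlib.nullcontext()


# Usage mirroring the rotary-embedding code path: keep the frequency math in float32.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
inv_freq_expanded = torch.rand(1, 4, 1)
position_ids_expanded = torch.arange(3).float()[None, None, :]
with maybe_autocast(device_type=device_type, enabled=False):
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
```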
@@ -31,7 +31,7 @@ from transformers.activations import ACT2FN
 from ... import initialization as init
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast

@@ -39,7 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_granitemoehybrid import GraniteMoeHybridConfig

@@ -132,6 +132,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class GraniteMoeHybridAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -157,7 +158,6 @@ class GraniteMoeHybridAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,

@@ -954,7 +954,7 @@ class GraniteMoeHybridRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -1510,8 +1510,6 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMix

         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
             # Flatten the tokens
             loss = self.loss_function(
                 logits,
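The `# Upcast to float ...` / `logits = logits.float()` pair is dropped from each CausalLM forward touched in this diff, leaving precision handling to `self.loss_function` itself. The helper below is only an illustration of that division of labour, not the library's loss implementation: the upcast happens once, inside the loss.

```python
import torch
import torch.nn.functional as F


def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Illustrative helper: the float32 upcast lives here instead of in every model's forward."""
    # Shift so that tokens < n predict token n, then flatten batch and sequence dimensions.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)).float(),  # upcast once, centrally
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )


# Half-precision logits still produce a float32 loss without the caller upcasting them.
logits = torch.randn(2, 5, 11, dtype=torch.float16)
labels = torch.randint(0, 11, (2, 5))
loss = causal_lm_loss(logits, labels)
```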