transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/gpt_neo/modeling_gpt_neo.py

@@ -20,6 +20,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
@@ -70,11 +71,11 @@ class GPTNeoSelfAttention(nn.Module):
         # local causal self attention is a sliding window where each token can only attend to the previous
         # window_size tokens. This is implemented by updating the causal mask such that for each token
         # all other tokens are masked except the previous window_size tokens.
+        self.attention_type = attention_type
         if attention_type == "local":
             bias = torch.bitwise_xor(bias, torch.tril(bias, -config.window_size))
 
         self.register_buffer("bias", bias, persistent=False)
-        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
 
         self.attn_dropout = nn.Dropout(float(config.attention_dropout))
         self.resid_dropout = nn.Dropout(float(config.resid_dropout))
@@ -237,8 +238,8 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
                 else torch.get_autocast_gpu_dtype()
             )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
 
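For context, the new branch keys off the presence of a `quantization_config` and reads the dtype recorded on the config, instead of the removed `_pre_quantization_dtype` attribute. A minimal standalone sketch of the rc2 selection rule; `config` and `q_proj` here are stand-in objects, not the real module attributes:

import torch

def resolve_target_dtype(input_dtype: torch.dtype, config, q_proj) -> torch.dtype:
    # Inputs already in half precision: keep them as-is.
    if input_dtype != torch.float32:
        return input_dtype
    # Respect an active autocast context first.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    # Quantized models: read the dtype stored on the config.
    if hasattr(config, "quantization_config"):
        return config.dtype
    # Fall back to the dtype of the projection weights.
    return q_proj.weight.dtype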
@@ -382,6 +383,17 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _can_compile_fullgraph = False  # TODO: needs a hybrid cache
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, GPTNeoSelfAttention):
+            max_positions = module.config.max_position_embeddings
+            bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view(
+                1, 1, max_positions, max_positions
+            )
+            if module.attention_type == "local":
+                bias = torch.bitwise_xor(bias, torch.tril(bias, -module.config.window_size))
+            init.copy_(module.bias, bias)
+
 
 @auto_docstring
 class GPTNeoModel(GPTNeoPreTrainedModel):
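The bias rebuilt in `_init_weights` is the same mask constructed in `__init__`, so re-initialization restores it. A self-contained sketch of what the `tril`/`bitwise_xor` combination produces, with toy sizes chosen only for illustration:

import torch

max_positions, window_size = 8, 3

# Full causal mask: position i may attend to positions 0..i.
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))

# XOR removes everything more than window_size - 1 steps in the past,
# leaving row i with ones only in columns max(0, i - window_size + 1)..i.
local = torch.bitwise_xor(bias, torch.tril(bias, -window_size))
print(local.int())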
@@ -419,6 +431,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):

@@ -773,6 +786,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):

@@ -894,6 +908,7 @@ class GPTNeoForTokenClassification(GPTNeoPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):

@@ -974,6 +989,7 @@ class GPTNeoForQuestionAnswering(GPTNeoPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
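The recurring `+        **kwargs,` lines in these hunks all serve the same purpose: the task heads now accept and ignore extra keyword arguments instead of raising. A toy illustration with hypothetical function names, not taken from the library:

def forward_old(input_ids, output_attentions=None, return_dict=None):
    return input_ids

def forward_new(input_ids, output_attentions=None, return_dict=None, **kwargs):
    return input_ids  # extra kwargs are accepted and ignored here

forward_new([1, 2, 3], return_dict=True, cache_position=None)   # fine
# forward_old([1, 2, 3], cache_position=None)  # TypeError: unexpected kwarg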
transformers/models/gpt_neox/modeling_gpt_neox.py

@@ -28,7 +28,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_gpt_neox import GPTNeoXConfig
 
 
@@ -66,7 +66,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -107,7 +107,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
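`maybe_autocast` replaces `torch.autocast` here but keeps the same intent: the rotary tables are computed in float32 even when autocast is active. A standalone approximation using plain `torch.autocast`, since `maybe_autocast` is transformers-internal; sizes are toy values:

import torch

batch, seq_len, dim = 1, 16, 64
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
position_ids = torch.arange(seq_len)[None, :]

inv_freq_expanded = inv_freq[None, :, None].expand(batch, -1, 1)
position_ids_expanded = position_ids[:, None, :].float()

with torch.autocast(device_type="cpu", enabled=False):  # force float32
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (batch, seq_len, dim)
    cos, sin = emb.cos(), emb.sin()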
@@ -645,6 +645,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
|
|
|
645
645
|
use_cache: Optional[bool] = None,
|
|
646
646
|
output_attentions: Optional[bool] = None,
|
|
647
647
|
output_hidden_states: Optional[bool] = None,
|
|
648
|
+
**kwargs,
|
|
648
649
|
) -> SequenceClassifierOutputWithPast:
|
|
649
650
|
r"""
|
|
650
651
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
|
@@ -724,6 +725,7 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
|
|
|
724
725
|
use_cache: Optional[bool] = None,
|
|
725
726
|
output_attentions: Optional[bool] = None,
|
|
726
727
|
output_hidden_states: Optional[bool] = None,
|
|
728
|
+
**kwargs,
|
|
727
729
|
) -> TokenClassifierOutput:
|
|
728
730
|
r"""
|
|
729
731
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
@@ -783,6 +785,7 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel):
|
|
|
783
785
|
end_positions: Optional[torch.LongTensor] = None,
|
|
784
786
|
output_attentions: Optional[bool] = None,
|
|
785
787
|
output_hidden_states: Optional[bool] = None,
|
|
788
|
+
**kwargs,
|
|
786
789
|
) -> QuestionAnsweringModelOutput:
|
|
787
790
|
outputs: BaseModelOutputWithPast = self.gpt_neox(
|
|
788
791
|
input_ids,
|
|
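Note: the RoPE forward paths swap a direct `torch.autocast` call for a `maybe_autocast` helper imported from `...utils.generic`. The diff does not show its body; presumably it wraps `torch.autocast` and degrades to a no-op on backends where autocast is unavailable. A hedged sketch of that idea (the real helper may differ):

```python
# Hedged sketch of what a helper like `maybe_autocast` plausibly does;
# the actual implementation in transformers.utils.generic may differ.
from contextlib import nullcontext

import torch


def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
    try:
        # enabled=False force-disables autocast so the RoPE math stays in
        # float32 even when the surrounding forward runs under bf16/fp16.
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:
        # Backends without autocast support fall back to doing nothing.
        return nullcontext()
```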
@@ -518,6 +518,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@@ -597,6 +598,7 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> TokenClassifierOutput:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -656,6 +658,7 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel):
         end_positions: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> QuestionAnsweringModelOutput:
         outputs: BaseModelOutputWithPast = self.gpt_neox(
             input_ids,
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Tokenization classes for GPTNeoX."""

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
 from tokenizers.models import BPE

@@ -87,51 +87,34 @@ class GPTNeoXTokenizer(TokenizersBackend):
         Whether or not to add an `eos_token` at the end of sequences.
     trim_offsets (`bool`, *optional*, defaults to `True`):
         Whether or not the post-processing step should trim offsets to avoid including whitespaces.
-    vocab (`dict`, *optional*):
-        Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file
-    merges (`list`, *optional*):
-        Custom merges list. If not provided, merges are loaded from merges_file
+    vocab (`str` or `dict[str, int]`, *optional*):
+        Custom vocabulary dictionary. If not provided, vocabulary is loaded from `vocab_file`.
+    merges (`str` or `list[str]`, *optional*):
+        Custom merges list. If not provided, merges are loaded from `merges_file`.
     """

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = BPE

     def __init__(
         self,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
         errors: str = "replace",
         unk_token: str = "<|endoftext|>",
         bos_token: str = "<|endoftext|>",
         eos_token: str = "<|endoftext|>",
         pad_token: str = "<|padding|>",
-        add_bos_token: bool = False,
-        add_eos_token: bool = False,
         add_prefix_space: bool = False,
         trim_offsets: bool = True,
-        vocab: Optional[dict] = None,
-        merges: Optional[list] = None,
         **kwargs,
     ):
-        self._add_bos_token = add_bos_token
-        self._add_eos_token = add_eos_token
         self.add_prefix_space = add_prefix_space
         self.trim_offsets = trim_offsets

-        if vocab is not None:
-            self._vocab = (
-                {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
-            )
-        else:
-            self._vocab = {
-                str(unk_token): 0,
-                str(pad_token): 1,
-            }
-
-        if merges is not None:
-            self._merges = merges
-        else:
-            self._merges = []
-
+        self._vocab = vocab if vocab is not None else {str(unk_token): 0, str(pad_token): 1}
+        self._merges = merges or []
         self._tokenizer = Tokenizer(
             BPE(
                 vocab=self._vocab,

@@ -149,38 +132,16 @@ class GPTNeoXTokenizer(TokenizersBackend):
         )
         self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=False, trim_offsets=True)

-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             errors=errors,
             unk_token=unk_token,
             bos_token=bos_token,
             eos_token=eos_token,
             pad_token=pad_token,
-            add_bos_token=add_bos_token,
-            add_eos_token=add_eos_token,
             add_prefix_space=add_prefix_space,
             trim_offsets=trim_offsets,
             **kwargs,
         )

-        self.update_post_processor()
-
-    def _post_init(self):
-        """Post-initialization to ensure tokenizer settings are applied correctly."""
-        # Re-apply settings to ensure they're correct after loading from pretrained
-        self._tokenizer.normalizer = normalizers.NFC()
-        self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(
-            add_prefix_space=self.add_prefix_space, trim_offsets=self.trim_offsets
-        )
-        self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=False, trim_offsets=True)
-
-        # Call parent to handle AddedToken properties
-        super()._post_init()
-
-        # Update post processor with current bos/eos settings
-        self.update_post_processor()
-

 __all__ = ["GPTNeoXTokenizer"]
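Note: the refactor moves `vocab`/`merges` to the front of the signature, drops the `add_bos_token`/`add_eos_token` plumbing, and leaves post-init to the backend. Assuming the new signature shown above and that the backend wires `__call__` up as before, constructing a toy tokenizer would look roughly like this (an illustrative sketch, not an official snippet):

```python
# Sketch under the new signature shown in the diff above; illustrative only.
from transformers.models.gpt_neox.tokenization_gpt_neox import GPTNeoXTokenizer

toy_vocab = {"<|endoftext|>": 0, "<|padding|>": 1, "h": 2, "i": 3, "hi": 4}
toy_merges = ["h i"]  # one BPE merge rule: "h" + "i" -> "hi"

tok = GPTNeoXTokenizer(vocab=toy_vocab, merges=toy_merges)
print(tok("hi")["input_ids"])  # expected: [4], via the single merge
```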
@@ -30,6 +30,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+from ...utils.generic import maybe_autocast
 from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig


@@ -77,7 +78,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(

@@ -116,7 +117,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -431,6 +432,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         r"""
         Example:
@@ -117,5 +117,22 @@ class GptOssConfig(PreTrainedConfig):
             **kwargs,
         )

+    def __setattr__(self, key, value):
+        """
+        Overwritten to allow checking for the proper attention implementation to be used.
+
+        Due to `set_attn_implementation` which internally assigns `_attn_implementation_internal = "..."`, simply overwriting
+        the specific attention setter is not enough. Using a property/setter for `_attn_implementation_internal` would result in
+        a recursive dependency (as `_attn_implementation` acts as a wrapper around `_attn_implementation_internal`) - hence, this
+        workaround.
+        """
+        if key in ("_attn_implementation", "_attn_implementation_internal"):
+            if value and "flash" in value and value.removeprefix("paged|") != "kernels-community/vllm-flash-attn3":
+                raise ValueError(
+                    f"GPT-OSS model does not support the specified flash attention implementation: {value}. "
+                    "Only `kernels-community/vllm-flash-attn3` is supported."
+                )
+        super().__setattr__(key, value)
+

 __all__ = ["GptOssConfig"]
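Note: with the `__setattr__` guard above, any flash-attention string other than the vllm-flash-attn3 hub kernel is rejected the moment it is assigned, whether by the user or by `set_attn_implementation` internally. A sketch of the resulting behavior (assuming a default-constructed config):

```python
# Illustrative behavior of the guard above; assumes GptOssConfig() defaults.
from transformers import GptOssConfig

config = GptOssConfig()

config._attn_implementation = "kernels-community/vllm-flash-attn3"        # accepted
config._attn_implementation = "paged|kernels-community/vllm-flash-attn3"  # accepted: "paged|" prefix is stripped

try:
    config._attn_implementation = "flash_attention_2"
except ValueError as e:
    print(e)  # GPT-OSS model does not support the specified flash attention implementation ...
```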
@@ -28,7 +28,7 @@ from torch.nn import functional as F
 from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import (
     GenericForSequenceClassification,

@@ -40,7 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_gpt_oss import GptOssConfig


@@ -88,8 +88,8 @@ class GptOssExperts(nn.Module):

         Args:
             hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
-            selected_experts (torch.Tensor): (batch_size *
-            routing_weights (torch.Tensor): (batch_size *
+            selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
+            routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
         Returns:
             torch.Tensor
         """

@@ -159,8 +159,8 @@ class GptOssTopKRouter(nn.Module):

     def forward(self, hidden_states):
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (
+        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (num_tokens, num_experts)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (num_tokens, top_k)
         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
         router_scores = router_top_value
         return router_logits, router_scores, router_indices

@@ -196,7 +196,7 @@ class GptOssRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(

@@ -235,7 +235,7 @@ class GptOssRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = freqs
             cos = emb.cos() * self.attention_scaling

@@ -301,12 +301,13 @@ def eager_attention_forward(
     combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
     probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
     scores = probs[..., :-1]  # we drop the sink here
-    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
+    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training).to(value_states.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class GptOssAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -332,7 +333,6 @@ class GptOssAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
         self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))

@@ -343,7 +343,6 @@ class GptOssAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]

@@ -373,7 +372,6 @@ class GptOssAttention(nn.Module):
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
             sliding_window=self.sliding_window,
-            position_ids=position_ids,
             s_aux=self.sinks,  # diff with Llama
             **kwargs,
         )

@@ -446,8 +444,6 @@ class GptOssPreTrainedModel(PreTrainedModel):
         "attentions": GptOssAttention,
     }
     _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
-    _supports_flash_attention = False
-    _supports_flex_attention = False

     @torch.no_grad()
     def _init_weights(self, module):
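Note: the `.to(value_states.dtype)` cast added in `eager_attention_forward` matters because the sink-augmented softmax/dropout can produce float32 probabilities while `value_states` stay in bf16/fp16, and `torch.matmul` refuses to mix dtypes. A self-contained demonstration of that failure and the fix:

```python
# Why the cast: torch.matmul raises on mixed dtypes.
import torch

attn_weights = torch.rand(1, 2, 4, 4, dtype=torch.float32)   # e.g. fp32 probs
value_states = torch.rand(1, 2, 4, 8, dtype=torch.bfloat16)  # bf16 values

try:
    torch.matmul(attn_weights, value_states)
except RuntimeError as e:
    print(e)  # dtype mismatch: expected BFloat16 but found Float

out = torch.matmul(attn_weights.to(value_states.dtype), value_states)  # works
```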
@@ -21,7 +21,7 @@ from torch.nn import functional as F

 from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache
-from ...integrations
+from ...integrations import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_outputs import (
     MoeModelOutputWithPast,

@@ -34,7 +34,7 @@ from ...utils import (
     auto_docstring,
     logging,
 )
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from ..llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaPreTrainedModel,

@@ -86,8 +86,8 @@ class GptOssExperts(nn.Module):

         Args:
             hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
-            selected_experts (torch.Tensor): (batch_size *
-            routing_weights (torch.Tensor): (batch_size *
+            selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
+            routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
         Returns:
             torch.Tensor
         """

@@ -157,8 +157,8 @@ class GptOssTopKRouter(nn.Module):

     def forward(self, hidden_states):
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (
+        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (num_tokens, num_experts)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (num_tokens, top_k)
         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
         router_scores = router_top_value
         return router_logits, router_scores, router_indices

@@ -185,7 +185,7 @@ class GptOssRotaryEmbedding(Qwen2RotaryEmbedding):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = freqs
             cos = emb.cos() * self.attention_scaling

@@ -239,7 +239,7 @@ def eager_attention_forward(
     combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
     probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
     scores = probs[..., :-1]  # we drop the sink here
-    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
+    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training).to(value_states.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attn_weights

@@ -269,7 +269,6 @@ class GptOssAttention(Qwen2Attention):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]

@@ -299,7 +298,6 @@ class GptOssAttention(Qwen2Attention):
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
             sliding_window=self.sliding_window,
-            position_ids=position_ids,
             s_aux=self.sinks,  # diff with Llama
             **kwargs,
         )

@@ -356,8 +354,6 @@ class GptOssDecoderLayer(LlamaDecoderLayer):
 class GptOssPreTrainedModel(LlamaPreTrainedModel):
     _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
     _supports_sdpa = False
-    _supports_flash_attention = False
-    _supports_flex_attention = False
     _can_record_outputs = {
         "router_logits": OutputRecorder(GptOssTopKRouter, index=0),
         "hidden_states": GptOssDecoderLayer,
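Note: across these hunks the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment is replaced by a class decorator, `@use_kernelized_func(apply_rotary_pos_emb)`. Its body is not shown in this diff; presumably it attaches the (possibly hub-kernelized) function at the class level so kernel substitution happens once per class rather than per instance. A loose sketch of the pattern, with everything except the `apply_rotary_pos_emb` name hypothetical:

```python
# Loose sketch of the decorator pattern; the real use_kernelized_func in
# transformers.integrations may differ substantially.
def use_kernelized_func(func):
    def decorator(cls):
        # Store the default implementation on the class; a kernel hub could
        # later swap this attribute for an optimized version.
        cls.rotary_fn = staticmethod(func)
        return cls
    return decorator


def apply_rotary_pos_emb(q, k, cos, sin):
    return q, k  # real rotary math elided; identity stand-in


@use_kernelized_func(apply_rotary_pos_emb)
class Attention:
    def forward(self, q, k, cos, sin):
        return self.rotary_fn(q, k, cos, sin)
```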
@@ -14,12 +14,14 @@
 # limitations under the License.
 """PyTorch GPT-J model."""

+import math
 from typing import Optional, Union

 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin

@@ -77,7 +79,7 @@ class GPTJAttention(nn.Module):
     def __init__(self, config, layer_idx=None):
         super().__init__()
         self.config = config
-        max_positions = config.max_position_embeddings
+        self.max_positions = config.max_position_embeddings

         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)

@@ -99,15 +101,17 @@ class GPTJAttention(nn.Module):
             f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
             f" `num_attention_heads`: {self.num_attention_heads})."
         )
-        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+        self.scale_attn = math.sqrt(self.head_dim)

         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.rotary_dim = config.rotary_dim
-        pos_embd_dim = self.rotary_dim or self.embed_dim
-        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
+        self.pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.register_buffer(
+            "embed_positions", create_sinusoidal_positions(self.max_positions, self.pos_embd_dim), persistent=False
+        )

     def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
         """

@@ -334,8 +338,8 @@ class GPTJFlashAttention2(GPTJAttention):
             else torch.get_autocast_gpu_dtype()
         )
         # Handle the case where the model is quantized
-        elif hasattr(self.config, "_pre_quantization_dtype"):
-            target_dtype = self.config._pre_quantization_dtype
+        elif hasattr(self.config, "quantization_config"):
+            target_dtype = self.config.dtype
         else:
             target_dtype = self.q_proj.weight.dtype

@@ -444,6 +448,11 @@ class GPTJPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _can_compile_fullgraph = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, GPTJAttention):
+            init.copy_(module.embed_positions, create_sinusoidal_positions(module.max_positions, module.pos_embd_dim))
+

 @auto_docstring
 class GPTJModel(GPTJPreTrainedModel):

@@ -482,6 +491,7 @@ class GPTJModel(GPTJPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):

@@ -819,6 +829,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):

@@ -930,6 +941,7 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
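Note: GPT-J's `embed_positions` table becomes a non-persistent buffer, recomputed in `_init_weights` via `init.copy_`. Non-persistent buffers are excluded from `state_dict`, so they must be re-materialized after loading; registering them (instead of keeping a plain attribute) still lets `.to()` / `.cuda()` move them with the module. A small self-contained illustration of both properties:

```python
import torch
from torch import nn


class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent: follows module moves, but is absent from state_dict,
        # so checkpoints stay small and loading never depends on it.
        self.register_buffer("table", torch.arange(4, dtype=torch.float32), persistent=False)


m = WithBuffer()
print(list(m.state_dict().keys()))  # [] -- "table" is not serialized
m = m.to(torch.float64)
print(m.table.dtype)                # torch.float64 -- buffers follow .to()
```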
@@ -28,7 +28,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

@@ -36,7 +36,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_granite import GraniteConfig


@@ -116,6 +116,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class GraniteAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -141,7 +142,6 @@ class GraniteAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,

@@ -337,7 +337,7 @@ class GraniteRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(

@@ -376,7 +376,7 @@ class GraniteRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
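Note: the recurring `original_inv_freq` change (here and in the GPT-NeoX and GPT-OSS hunks above) replaces a plain attribute with a registered non-persistent buffer holding a clone of the initial `inv_freq`. Cloning matters because dynamic-RoPE paths rewrite `inv_freq` in place when sequences exceed the original context; a buffered clone preserves the defaults for later restoration, and registration keeps the copy on the module's device. A sketch of the aliasing bug the clone avoids:

```python
import torch

inv_freq = torch.tensor([1.0, 0.5, 0.25])

alias = inv_freq         # plain-attribute semantics: same storage
copy = inv_freq.clone()  # buffered clone: independent storage

inv_freq.mul_(0.5)       # e.g. a dynamic-RoPE rescale done in place
print(alias)             # rescaled too -- the "original" is lost
print(copy)              # still [1.0, 0.5, 0.25]
```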