transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -26,12 +26,12 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast
 from ...modeling_rope_utils import RopeParameters
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import check_model_inputs
 from ..auto import AutoModel
 from ..gemma2.configuration_gemma2 import Gemma2Config
 from ..gemma2.modeling_gemma2 import (

@@ -1474,7 +1474,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
     )

     def forward(
-        self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
+        self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs
     ) -> tuple[torch.Tensor, torch.BoolTensor]:
         """Encodes a batch of MELs.

@@ -1742,7 +1742,7 @@ class Gemma3nTextAttention(Gemma3Attention):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.config.head_dim)

@@ -1830,10 +1830,8 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         predictions = self.altup.predict(hidden_states)
         active_prediction = predictions[self.config.altup_active_idx]

@@ -1841,14 +1839,12 @@
         active_prediction_normed = self.input_layernorm(active_prediction)
         laurel_output = self.laurel(active_prediction_normed)

-        attn, self_attn_weights = self.self_attn(
+        attn, _ = self.self_attn(
             hidden_states=active_prediction_normed,
             attention_mask=attention_mask,
             position_ids=position_ids,
             position_embeddings=position_embeddings,
             past_key_values=past_key_values,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )

@@ -1877,18 +1873,17 @@
         first_prediction = self.post_per_layer_input_norm(first_prediction)
         corrected_predictions[1:] += first_prediction

-
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return corrected_predictions


 class Gemma3nPreTrainedModel(Gemma2PreTrainedModel):
     config: Gemma3nConfig
     input_modalities = ("image", "text", "audio")
     _no_split_modules = ["Gemma3nTextDecoderLayer"]
+    _can_record_outputs = {
+        "hidden_states": Gemma3nTextDecoderLayer,
+        "attentions": Gemma3nTextAttention,
+    }

     @torch.no_grad()
     def _init_weights(self, module):

@@ -1976,7 +1971,8 @@ class Gemma3nTextModel(Gemma3TextModel):
             dtype=inputs_embeds.dtype, device=per_layer_projection.device
         )

-
+    # Last hidden states should be before reprojecting, to stay consistent with the other layer outputs
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,

@@ -1987,8 +1983,6 @@
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
@@ -1996,37 +1990,21 @@
         per_layer_inputs (torch.Tensor, *optional*, defaults to None):
             Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
         if input_ids is not None:
             inputs_embeds = self.embed_tokens(input_ids)
             per_layer_inputs = self.get_per_layer_inputs(input_ids)

         per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)

-        if use_cache and past_key_values is None
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
-                past_seen_tokens,
-                past_seen_tokens + inputs_embeds.shape[1],
-                device=inputs_embeds.device,
-            )
+            cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
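The `cache_position` change above is a pure refactor: adding the past length to a zero-based range produces the same tensor as building the range with explicit bounds. A quick standalone check (illustrative values):

    import torch

    past_seen_tokens, seq_len = 7, 5
    old = torch.arange(past_seen_tokens, past_seen_tokens + seq_len)
    new = torch.arange(seq_len) + past_seen_tokens
    assert torch.equal(old, new)  # both are tensor([ 7,  8,  9, 10, 11])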
@@ -2070,39 +2048,21 @@
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
             causal_mask = causal_mask_mapping[decoder_layer.attention_type]
             per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]

-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 position_embeddings[decoder_layer.attention_type],
                 per_layer_input,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        # add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output)
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         # Per-layer inputs to single output
         target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
         temp_hidden_states = [hidden_states[0]]
@@ -2122,8 +2082,6 @@
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )

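Taken together, the gemma3n hunks above drop the hand-rolled `output_attentions`/`output_hidden_states` plumbing in favor of the `@check_model_inputs` decorator plus a `_can_record_outputs` map from output names to the module classes that produce them. The sketch below illustrates the general idea of hook-based output recording; it is a simplified assumption, not the actual implementation in `transformers/utils/generic.py`:

    import torch.nn as nn

    def record_module_outputs(model: nn.Module, can_record: dict[str, type]):
        """Collect sub-module outputs by name via forward hooks (simplified)."""
        records = {name: [] for name in can_record}
        handles = []
        for name, module_cls in can_record.items():
            for submodule in model.modules():
                if isinstance(submodule, module_cls):
                    handles.append(
                        submodule.register_forward_hook(
                            lambda mod, args, out, _name=name: records[_name].append(out)
                        )
                    )
        return records, handles

    # Usage sketch: hooks fill `records` during one forward pass, then detach.
    # records, handles = record_module_outputs(model, {"attentions": MyAttention})
    # model(inputs); [h.remove() for h in handles]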
@@ -2284,7 +2242,7 @@ class Gemma3nModel(PaliGemmaModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Gemma3nCausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -2456,7 +2414,7 @@ class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Gemma3nCausalLMOutputWithPast:
         r"""
         input_features_mask (torch.Tensor, *optional*, defaults to None):
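Several signatures in this release change bare `**kwargs` (or ad-hoc `Unpack[...]` annotations) to `**kwargs: Unpack[TransformersKwargs]`, which lets static type-checkers validate the accepted keyword arguments. A minimal sketch of the pattern; `ExampleKwargs` is a hypothetical stand-in, not the real `TransformersKwargs` definition:

    from typing_extensions import TypedDict, Unpack  # typing.Unpack on Python 3.12+

    class ExampleKwargs(TypedDict, total=False):
        output_attentions: bool
        output_hidden_states: bool

    def forward(**kwargs: Unpack[ExampleKwargs]) -> None:
        # A type-checker now flags e.g. forward(output_attention=True) as a typo.
        print(kwargs.get("output_attentions", False))

    forward(output_attentions=True)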
@@ -827,6 +827,7 @@ class GitVisionModel(GitPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         Examples:

@@ -972,6 +973,7 @@ class GitModel(GitPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
         r"""
         Examples:

@@ -28,7 +28,7 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import (
     GenericForSequenceClassification,

@@ -40,7 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_glm import GlmConfig
@@ -120,7 +120,7 @@ class GlmRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
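From here on, the float32-forcing blocks in the rotary embeddings go through a `maybe_autocast` helper imported from `transformers/utils/generic.py`. As a hedged illustration only (an assumption, not the library's actual code), such a wrapper could be an autocast context that degrades to a no-op where autocast is unsupported:

    import contextlib
    import torch

    @contextlib.contextmanager
    def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
        """Illustrative stand-in: torch.autocast when the backend supports it."""
        try:
            ctx = torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
        except RuntimeError:  # e.g. a device type without autocast support
            ctx = contextlib.nullcontext()
        with ctx:
            yield

    # Usage mirrors the diff: disable autocast to force float32 math.
    # with maybe_autocast(device_type="cpu", enabled=False):
    #     ...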
@@ -216,6 +216,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class GlmAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -239,7 +240,6 @@ class GlmAttention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
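The new `@use_kernelized_func(apply_rotary_pos_emb)` decorator replaces the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment each attention `__init__` used to carry. A minimal sketch of the decorator shape, assuming it merely binds the default implementation at class level; the real version in `transformers/integrations` presumably also lets a hub kernel be swapped in:

    def use_kernelized_func(func):
        """Illustrative decorator: attach `func` as the class-level rotary_fn."""
        def decorator(cls):
            # A kernels backend could later replace this attribute wholesale.
            cls.rotary_fn = staticmethod(func)
            return cls
        return decorator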
@@ -28,7 +28,7 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (

@@ -41,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_glm4 import Glm4Config

@@ -198,6 +198,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class Glm4Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -221,7 +222,6 @@ class Glm4Attention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,

@@ -325,7 +325,7 @@ class Glm4RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer

@@ -39,7 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_glm4_moe import Glm4MoeConfig

@@ -101,7 +101,7 @@ class Glm4MoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -193,6 +193,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class Glm4MoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -491,6 +492,7 @@ class Glm4MoePreTrainedModel(PreTrainedModel):
         "hidden_states": Glm4MoeDecoderLayer,
         "attentions": Glm4MoeAttention,
     }
+    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]

     @torch.no_grad()
     def _init_weights(self, module):
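Both GLM MoE variants now pin `e_score_correction_bias` through `_keep_in_fp32_modules_strict`, keeping that router parameter in float32 even when the rest of the checkpoint is cast down. A toy sketch of the mechanism, assuming simple substring matching on parameter names (the loader's real matching rules may differ):

    import torch

    def cast_except(state_dict, keep_in_fp32, dtype=torch.bfloat16):
        """Cast every tensor except those whose name matches a kept pattern."""
        return {
            name: t if any(pat in name for pat in keep_in_fp32) else t.to(dtype)
            for name, t in state_dict.items()
        }

    sd = {"mlp.gate.e_score_correction_bias": torch.zeros(4), "mlp.up.weight": torch.zeros(4)}
    out = cast_except(sd, ["e_score_correction_bias"])
    assert out["mlp.gate.e_score_correction_bias"].dtype == torch.float32
    assert out["mlp.up.weight"].dtype == torch.bfloat16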
@@ -234,7 +234,9 @@ class Glm4vTextConfig(PreTrainedConfig):
         self.attention_dropout = attention_dropout
         self.rope_parameters = rope_parameters

-        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+        )


 class Glm4vConfig(PreTrainedConfig):

@@ -40,7 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig

@@ -446,7 +446,7 @@ class Glm4vTextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -768,7 +768,7 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb, pos_ids

-    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):

@@ -36,7 +36,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from ...video_utils import VideoInput
 from ..glm4.modeling_glm4 import Glm4MLP, Glm4RMSNorm, Glm4RotaryEmbedding, eager_attention_forward
 from ..qwen2_5_vl.modeling_qwen2_5_vl import (

@@ -271,7 +271,9 @@ class Glm4vTextConfig(PreTrainedConfig):
         self.attention_dropout = attention_dropout
         self.rope_parameters = rope_parameters

-        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+        )


 class Glm4vConfig(PreTrainedConfig):

@@ -509,7 +511,7 @@ class Glm4vTextRotaryEmbedding(Glm4RotaryEmbedding):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -786,7 +788,7 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb, pos_ids

-    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):

@@ -280,7 +280,9 @@ class Glm4vMoeTextConfig(PreTrainedConfig):
         self.first_k_dense_replace = first_k_dense_replace
         self.norm_topk_prob = norm_topk_prob
         self.router_aux_loss_coef = router_aux_loss_coef
-        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+        )


 class Glm4vMoeConfig(PreTrainedConfig):

@@ -32,7 +32,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer

@@ -41,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_glm4v_moe import Glm4vMoeConfig, Glm4vMoeTextConfig, Glm4vMoeVisionConfig

@@ -150,7 +150,7 @@ class Glm4vMoeTextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -299,6 +299,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class Glm4vMoeTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -322,7 +323,6 @@ class Glm4vMoeTextAttention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb
         self.rope_parameters = config.rope_parameters

     def forward(

@@ -594,6 +594,7 @@ class Glm4vMoePreTrainedModel(PreTrainedModel):
         "attentions": Glm4vMoeTextAttention,
         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
     }
+    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
     input_modalities = ("text", "image", "video")

     @torch.no_grad()

@@ -975,7 +976,7 @@ class Glm4vMoeVisionModel(Glm4vMoePreTrainedModel):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb, pos_ids

-    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):

@@ -227,7 +227,7 @@ class Glm4vMoeTextConfig(Glm4MoeConfig, RotaryEmbeddingConfigMixin):
         self.norm_topk_prob = norm_topk_prob
         self.router_aux_loss_coef = router_aux_loss_coef
         PreTrainedConfig.__init__(
-            self, tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"
+            self, tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
         )

@@ -411,6 +411,7 @@ class GLPNModel(GLPNPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

@@ -597,6 +598,7 @@ class GLPNForDepthEstimation(GLPNPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], DepthEstimatorOutput]:
         r"""
         labels (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*):

@@ -45,6 +45,7 @@ from ...utils import (
     auto_docstring,
     logging,
 )
+from ...utils.generic import maybe_autocast
 from .configuration_gpt2 import GPT2Config
@@ -150,7 +151,7 @@ class GPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.amp.autocast(query.device.type, enabled=False):
+        with maybe_autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
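For context on the upcast path: with `beta=0`, `torch.baddbmm(input, b1, b2, beta=0, alpha=a)` ignores the values of `input` (only shape and dtype matter) and returns `a * (b1 @ b2)`, here evaluated in float32. A small self-contained check:

    import torch

    q = torch.randn(2, 4, 8)
    k = torch.randn(2, 8, 4)
    scratch = torch.empty(2, 4, 4)  # contents irrelevant because beta=0
    out = torch.baddbmm(scratch, q, k, beta=0, alpha=0.125)
    assert torch.allclose(out, 0.125 * torch.bmm(q, k))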
@@ -1021,6 +1022,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):

@@ -1148,6 +1150,7 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):

@@ -1228,6 +1231,7 @@ class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):