transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0

transformers/models/paddleocr_vl/processing_paddleocr_vl.py (new file)

@@ -0,0 +1,135 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_paddleocr_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class PaddleOCRVLProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+    }
+
+
+class PaddleOCRVLProcessor(ProcessorMixin):
+    r"""
+    [`PaddleOCRVLProcessor`] offers all the functionalities of [`PaddleOCRVLImageProcessor`] and [`LLamaTokenizerFast`]. See the
+    [`~PaddleOCRVLProcessor.__call__`] and [`~PaddleOCRVLProcessor.decode`] for more information.
+    Args:
+        image_processor ([`PaddleOCRVLImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`LLamaTokenizerFast`], *optional*):
+            The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
+    """
+
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+        self.image_token = tokenizer.image_token
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+        **kwargs: Unpack[PaddleOCRVLProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
+        """
+        output_kwargs = self._merge_kwargs(
+            PaddleOCRVLProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        if images is not None:
+            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+            image_grid_thw = image_inputs["image_grid_thw"]
+
+        else:
+            image_inputs = {}
+            image_grid_thw = None
+
+        if not isinstance(text, list):
+            text = [text]
+
+        text = text.copy()
+
+        if image_grid_thw is not None:
+            index = 0
+            for i in range(len(text)):
+                while self.image_token in text[i]:
+                    text[i] = text[i].replace(
+                        self.image_token,
+                        "<|placeholder|>"
+                        * (
+                            image_grid_thw[index].prod()
+                            // self.image_processor.merge_size
+                            // self.image_processor.merge_size
+                        ),
+                        1,
+                    )
+                    index += 1
+                text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+        return BatchFeature(data={**text_inputs, **image_inputs})
+
+
+__all__ = ["PaddleOCRVLProcessor"]
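The `__call__` loop above expands each image token in the prompt into `t * h * w // merge_size**2` placeholder copies, one per merged vision patch, before handing the text to the tokenizer. A standalone sketch of that arithmetic, with made-up grid and merge values rather than anything taken from this diff:

    # Hypothetical numbers: one image whose patch grid is (t=1, h=28, w=28),
    # merged 2x2 spatially before reaching the language model.
    image_grid_thw = (1, 28, 28)
    merge_size = 2

    t, h, w = image_grid_thw
    num_placeholders = (t * h * w) // merge_size // merge_size
    print(num_placeholders)  # 196 placeholder tokens stand in for one image token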
transformers/models/parakeet/configuration_parakeet.py

@@ -121,9 +121,6 @@ class ParakeetEncoderConfig(PreTrainedConfig):
         initializer_range=0.02,
         **kwargs,
     ):
-        super().__init__(
-            **kwargs,
-        )
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
@@ -133,10 +130,7 @@ class ParakeetEncoderConfig(PreTrainedConfig):
         self.attention_bias = attention_bias
         self.convolution_bias = convolution_bias
 
-        if (conv_kernel_size - 1) % 2 != 0:
-            raise ValueError(f"conv_kernel_size must be odd, got {conv_kernel_size}")
         self.conv_kernel_size = conv_kernel_size
-
         self.subsampling_conv_kernel_size = subsampling_conv_kernel_size
         self.subsampling_conv_stride = subsampling_conv_stride
 
@@ -153,6 +147,10 @@ class ParakeetEncoderConfig(PreTrainedConfig):
         self.scale_input = scale_input
         self.initializer_range = initializer_range
 
+        super().__init__(
+            **kwargs,
+        )
+
 
 class ParakeetCTCConfig(PreTrainedConfig):
     r"""
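Two things change here: `super().__init__(**kwargs)` now runs after all attributes are set, and the deleted `ValueError` no longer enforces the odd-kernel constraint that symmetric "same" padding relies on. A small illustration of that constraint, not taken from the transformers code:

    import torch
    import torch.nn as nn

    kernel_size = 9  # odd, so (9 - 1) // 2 == 4 elements of padding per side
    conv = nn.Conv1d(8, 8, kernel_size, padding=(kernel_size - 1) // 2, groups=8)
    x = torch.randn(2, 8, 100)
    print(conv(x).shape[-1])  # 100: an odd kernel preserves the sequence length
    # An even kernel (say 8, padding 3) would shorten the output to 99.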
transformers/models/parakeet/modeling_parakeet.py

@@ -29,13 +29,13 @@ from torch import nn
 
 from ... import initialization as init
 from ...activations import ACT2FN
-from ...integrations import use_kernel_func_from_hub
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
 
 
@@ -88,7 +88,7 @@ class ParakeetEncoderRelPositionalEncoding(nn.Module):
             if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
             else "cpu"
         )
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             sin = freqs.sin()
             cos = freqs.cos()
@@ -155,7 +155,7 @@ class ParakeetEncoderConvolutionModule(nn.Module):
 
         Args:
             hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
-            attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
+            attention_mask (`torch.Tensor` of shape `(batch, 1, time, time)`): Attention mask.
 
         Returns:
             `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
@@ -171,7 +171,10 @@ class ParakeetEncoderConvolutionModule(nn.Module):
 
         # Apply padding mask before convolution
         if attention_mask is not None:
-            all_masked_rows = torch.all(~attention_mask, dim=2)
+            if attention_mask.dtype == torch.bool:
+                all_masked_rows = torch.all(~attention_mask, dim=2)
+            else:
+                all_masked_rows = torch.all(~(attention_mask == 0.0), dim=2)
             hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
 
         # 1D Depthwise Conv
@@ -256,6 +259,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class ParakeetEncoderAttention(nn.Module):
     """Multi-head attention with relative positional encoding. See section 3.3 of https://huggingface.co/papers/1901.02860."""
 
@@ -281,7 +285,6 @@ class ParakeetEncoderAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
-        self.rotary_fn = apply_rotary_pos_emb
         # W_{k,R} projection
         self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
         # global content bias
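The new mask branch accepts both conventions in use across the library: boolean masks flag valid positions with `True`, while additive float masks hold `0.0` at valid positions and a large negative value elsewhere. A self-contained check of the equivalence the two branches compute (shapes illustrative):

    import torch

    bool_mask = torch.tensor([True, True, False]).reshape(1, 1, 1, 3)
    float_mask = torch.zeros(1, 1, 1, 3)
    float_mask[~bool_mask] = torch.finfo(torch.float32).min

    rows_bool = torch.all(~bool_mask, dim=2)
    rows_float = torch.all(~(float_mask == 0.0), dim=2)
    print(torch.equal(rows_bool, rows_float))  # True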
transformers/models/parakeet/modular_parakeet.py

@@ -29,7 +29,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule
 from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
 from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
@@ -84,7 +84,7 @@ class ParakeetEncoderRelPositionalEncoding(nn.Module):
             if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
             else "cpu"
         )
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             sin = freqs.sin()
             cos = freqs.cos()
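Both Parakeet files (and modeling_persimmon.py below) swap `torch.autocast` for a `maybe_autocast` helper imported from `transformers.utils.generic`; its body is not part of this diff. A hypothetical sketch of what such a wrapper could look like, assuming it merely defers to `torch.autocast` and degrades to a no-op context for device types autocast rejects:

    from contextlib import nullcontext

    import torch

    def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
        # Stand-in only: the real helper lives in transformers/utils/generic.py
        # and is not shown in this diff.
        try:
            return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
        except RuntimeError:  # unsupported device type
            return nullcontext()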
transformers/models/patchtsmixer/modeling_patchtsmixer.py

@@ -1141,6 +1141,7 @@ class PatchTSMixerEncoder(PatchTSMixerPreTrainedModel):
         past_values: torch.Tensor,
         output_hidden_states: Optional[bool] = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, PatchTSMixerEncoderOutput]:
         r"""
         past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
@@ -1251,6 +1252,7 @@ class PatchTSMixerModel(PatchTSMixerPreTrainedModel):
         observed_mask: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> PatchTSMixerModelOutput:
         r"""
         past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
@@ -1362,6 +1364,7 @@ class PatchTSMixerForPretraining(PatchTSMixerPreTrainedModel):
         output_hidden_states: Optional[bool] = False,
         return_loss: bool = True,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> PatchTSMixerForPreTrainingOutput:
         r"""
         past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
@@ -1574,6 +1577,7 @@ class PatchTSMixerForPrediction(PatchTSMixerPreTrainedModel):
         output_hidden_states: Optional[bool] = False,
         return_loss: bool = True,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> PatchTSMixerForPredictionOutput:
         r"""
         past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
@@ -1797,6 +1801,7 @@ class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
         output_hidden_states: Optional[bool] = False,
         return_loss: bool = True,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> PatchTSMixerForTimeSeriesClassificationOutput:
         r"""
         past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
@@ -1987,6 +1992,7 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
         output_hidden_states: Optional[bool] = False,
         return_loss: bool = True,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> PatchTSMixerForRegressionOutput:
         r"""
         past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
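Each PatchTSMixer `forward` above gains a trailing `**kwargs`, making the signatures tolerant of extra keyword arguments (attention-implementation kwargs, for instance) that generic calling code threads through. A toy illustration of the pattern, not the library class:

    import torch
    import torch.nn as nn

    class ToyEncoder(nn.Module):
        def forward(self, past_values, output_hidden_states=False, **kwargs):
            # Unrecognized keywords are absorbed instead of raising TypeError.
            return past_values

    enc = ToyEncoder()
    out = enc(torch.randn(2, 8), output_attentions=False)  # extra kwarg is fine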
transformers/models/patchtst/modeling_patchtst.py

@@ -24,6 +24,7 @@ from torch import nn
 
 from ... import initialization as init
 from ...activations import ACT2CLS
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -418,7 +419,7 @@ class PatchTSTEncoderLayer(nn.Module):
         super().__init__()
 
         self.channel_attention = config.channel_attention
-
+
         self.self_attn = PatchTSTAttention(
             embed_dim=config.d_model,
             num_heads=config.num_attention_heads,
@@ -555,6 +556,9 @@ class PatchTSTPreTrainedModel(PreTrainedModel):
     main_input_name = "past_values"
     input_modalities = ("time",)
     supports_gradient_checkpointing = False
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
 
     @torch.no_grad()
     def _init_weights(self, module: nn.Module):
@@ -571,7 +575,15 @@ class PatchTSTPreTrainedModel(PreTrainedModel):
                 init.normal_(module.cls_token, std=0.02)
                 num_patches += 1
             # initialize positional encoding
-            init.copy_(module.position_enc, module._init_pe(self.config, num_patches))
+            position_enc = module._init_pe(self.config, num_patches)
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                with deepspeed.zero.GatheredParameters(module.position_enc, modifier_rank=None):
+                    if module.position_enc.numel() > 0:
+                        init.copy_(module.position_enc, position_enc)
+            else:
+                init.copy_(module.position_enc, position_enc)
         elif isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
@@ -704,6 +716,7 @@ class PatchTSTEncoder(PatchTSTPreTrainedModel):
         patch_input: torch.Tensor,
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
+        **kwargs,
     ) -> BaseModelOutput:
         """
         Parameters:
@@ -1092,6 +1105,7 @@ class PatchTSTModel(PatchTSTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, PatchTSTModelOutput]:
         r"""
         Parameters:
@@ -1228,6 +1242,7 @@ class PatchTSTForPretraining(PatchTSTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, PatchTSTForPretrainingOutput]:
         r"""
         Parameters:
@@ -1387,6 +1402,7 @@ class PatchTSTForClassification(PatchTSTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, PatchTSTForClassificationOutput]:
         r"""
         past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
@@ -1594,6 +1610,7 @@ class PatchTSTForPrediction(PatchTSTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, PatchTSTForPredictionOutput]:
         r"""
         Parameters:
@@ -1840,6 +1857,7 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, PatchTSTForRegressionOutput]:
         r"""
         past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
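The `_init_weights` change is there because under DeepSpeed ZeRO stage 3 each rank materializes only a shard of `position_enc`: the parameter has to be gathered before an in-place write, and ranks whose local shard is empty skip the copy. A generic sketch of the pattern with a hypothetical helper (assuming `deepspeed` is installed):

    import torch

    def zero3_safe_copy(param: torch.nn.Parameter, value: torch.Tensor, zero3_enabled: bool) -> None:
        # Hypothetical helper mirroring the guard in the hunk above.
        if zero3_enabled:
            import deepspeed

            with deepspeed.zero.GatheredParameters(param, modifier_rank=None):
                if param.numel() > 0:  # only ranks holding data write
                    param.data.copy_(value)
        else:
            param.data.copy_(value)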
transformers/models/pegasus/modeling_pegasus.py

@@ -518,6 +518,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -695,6 +696,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         cache_position=None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -946,6 +948,7 @@ class PegasusModel(PegasusPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqModelOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1111,6 +1114,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqLMOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1283,6 +1287,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
transformers/models/pegasus/tokenization_pegasus.py

@@ -14,6 +14,8 @@
 # limitations under the License.
 """Tokenization class for model PEGASUS."""
 
+from typing import Optional, Union
+
 from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import Unigram
 
@@ -70,15 +72,17 @@ class PegasusTokenizer(TokenizersBackend):
         that uses the tokens 2 - 104 only for pretraining
         offset (`int`, *optional*, defaults to 103):
             Offset for additional special tokens.
-        vocab (`list`, *optional*):
-            Custom vocabulary
+        vocab (`str` or `list[tuple[str, float]]`, *optional*):
+            Custom vocabulary with `(token, score)` tuples. If not provided, a blank vocabulary is initialized.
     """
 
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
+    model = Unigram
 
     def __init__(
         self,
+        vocab: Optional[Union[str, list[tuple[str, float]]]] = None,
         pad_token="<pad>",
         eos_token="</s>",
         unk_token="<unk>",
@@ -86,60 +90,27 @@ class PegasusTokenizer(TokenizersBackend):
         mask_token_sent="<mask_1>",
         additional_special_tokens=None,
         offset=103,
-        vocab=None,
-        vocab_file=None,
         **kwargs,
     ):
         self.offset = offset
-        self.vocab_file = vocab_file
 
         if additional_special_tokens is None:
             additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
             additional_special_tokens += [f"<unk_{i}>" for i in range(2, self.offset)]
 
-        if vocab is not None:
-
-            special_tokens_set = {pad_token, eos_token, mask_token_sent, mask_token, unk_token}
-            special_tokens_set.update(additional_special_tokens)
-
-            # Build special tokens in correct order
-            _vocab_list = [
-                (str(pad_token), 0.0),
-                (str(eos_token), 0.0),
-            ]
-            if mask_token_sent:
-                _vocab_list.append((str(mask_token_sent), 0.0))
-            for token in additional_special_tokens:
-                if token not in [pad_token, eos_token, mask_token_sent]:
-                    _vocab_list.append((str(token), 0.0))
-            if mask_token not in [t for t, _ in _vocab_list]:
-                _vocab_list.append((str(mask_token), 0.0))
-            _vocab_list.append((str(unk_token), 0.0))
-
-            # Filter out special tokens from main vocab and combine
-            filtered_vocab = [(t, s) for t, s in vocab if t not in special_tokens_set]
-            _vocab_list = _vocab_list + filtered_vocab
-        else:
-            _vocab_list = [(str(unk_token), 0.0)]
-
-        self._vocab = {token: idx for idx, (token, _) in enumerate(_vocab_list)}
-
-        self._tokenizer = Tokenizer(Unigram(vocab=_vocab_list, unk_id=self._vocab.get(str(unk_token), 0)))
+        if vocab is None:
+            vocab = [(str(unk_token), 0.0), (str(pad_token), 0.0), (str(eos_token), 0.0), (str(mask_token), 0.0)]
 
+        self._vocab = vocab
+        self._tokenizer = Tokenizer(Unigram(vocab=vocab, unk_id=self._vocab.index((str(unk_token), 0.0), 1)))
         self._tokenizer.normalizer = normalizers.Sequence(
             [normalizers.Replace(Regex(r"\n"), " "), normalizers.Replace(Regex(r" {2,}"), " ")]
         )
 
-        self._tokenizer.post_processor = processors.TemplateProcessing(
-            single=f"$A {eos_token}",
-            pair=f"$A $B {eos_token}",
-            special_tokens=[(str(eos_token), self._vocab.get(str(eos_token), 1))],
-        )
-
-        tokenizer_object = self._tokenizer
+        self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True)
+        self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
 
         super().__init__(
-            tokenizer_object=tokenizer_object,
             pad_token=pad_token,
             eos_token=eos_token,
             unk_token=unk_token,
@@ -149,9 +120,11 @@ class PegasusTokenizer(TokenizersBackend):
             additional_special_tokens=additional_special_tokens,
             **kwargs,
         )
-
-
-
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=f"$A {eos_token}",
+            pair=f"$A $B {eos_token}",
+            special_tokens=[(str(eos_token), self.convert_tokens_to_ids(str(eos_token)))],
+        )
 
 
 __all__ = ["PegasusTokenizer"]
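The rewritten constructor keeps the vocabulary as a plain list of `(token, score)` pairs, which is exactly the format `tokenizers.models.Unigram` consumes, with `unk_id` given as an index into that list. A standalone sketch of the same construction using made-up tokens and scores:

    from tokenizers import Tokenizer, pre_tokenizers
    from tokenizers.models import Unigram

    vocab = [("<unk>", 0.0), ("<pad>", 0.0), ("</s>", 0.0), ("▁hello", -1.5)]
    tok = Tokenizer(Unigram(vocab=vocab, unk_id=0))  # unk_id is a position in vocab
    tok.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always")
    print(tok.encode("hello").tokens)  # ['▁hello']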
transformers/models/pegasus_x/modeling_pegasus_x.py

@@ -821,6 +821,7 @@ class PegasusXEncoder(PegasusXPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -989,6 +990,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         cache_position=None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -1241,6 +1243,7 @@ class PegasusXModel(PegasusXPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqModelOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1388,6 +1391,7 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqLMOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
transformers/models/perceiver/modeling_perceiver.py

@@ -615,6 +615,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, PerceiverModelOutput]:
         r"""
         inputs (`torch.FloatTensor`):
@@ -850,6 +851,7 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
         input_ids: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, PerceiverMaskedLMOutput]:
         r"""
         inputs (`torch.FloatTensor`):
@@ -975,6 +977,7 @@ class PerceiverForSequenceClassification(PerceiverPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
         input_ids: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, PerceiverClassifierOutput]:
         r"""
         inputs (`torch.FloatTensor`):
@@ -1107,6 +1110,7 @@ class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
         pixel_values: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, PerceiverClassifierOutput]:
         r"""
         inputs (`torch.FloatTensor`):
@@ -1229,6 +1233,7 @@ class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
         pixel_values: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, PerceiverClassifierOutput]:
         r"""
         inputs (`torch.FloatTensor`):
@@ -1350,6 +1355,7 @@ class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
         pixel_values: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, PerceiverClassifierOutput]:
         r"""
         inputs (`torch.FloatTensor`):
@@ -1487,6 +1493,7 @@ class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, PerceiverClassifierOutput]:
         r"""
         inputs (`torch.FloatTensor`):
@@ -1695,6 +1702,7 @@ class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, PerceiverClassifierOutput]:
         r"""
         inputs (`torch.FloatTensor`):
transformers/models/persimmon/modeling_persimmon.py

@@ -46,6 +46,7 @@ from ...modeling_rope_utils import (
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils.generic import maybe_autocast
 from .configuration_persimmon import PersimmonConfig
 
 
@@ -118,7 +119,7 @@ class PersimmonRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling