transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
"""Tokenization classes for CLIP."""
|
|
16
16
|
|
|
17
|
-
from typing import Optional
|
|
17
|
+
from typing import Optional, Union
|
|
18
18
|
|
|
19
19
|
from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
|
|
20
20
|
from tokenizers.models import BPE
|
|
@@ -37,6 +37,10 @@ class CLIPTokenizer(TokenizersBackend):
|
|
|
37
37
|
refer to this superclass for more information regarding those methods.
|
|
38
38
|
|
|
39
39
|
Args:
|
|
40
|
+
vocab (`str`, `dict` or `list`, *optional*):
|
|
41
|
+
Vocabulary dict to use for the tokenizer.
|
|
42
|
+
merges (`str` or `list`, *optional*):
|
|
43
|
+
Merges list to use for the BPE tokenizer.
|
|
40
44
|
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
|
41
45
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
|
42
46
|
token instead.
|
|
@@ -46,53 +50,38 @@ class CLIPTokenizer(TokenizersBackend):
|
|
|
46
50
|
The end of sequence token.
|
|
47
51
|
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
|
48
52
|
The token used for padding, for example when batching sequences of different lengths.
|
|
49
|
-
vocab (`dict`, *optional*):
|
|
50
|
-
Vocabulary dict to use for the tokenizer.
|
|
51
|
-
merges (`list`, *optional*):
|
|
52
|
-
Merges list to use for the BPE tokenizer.
|
|
53
|
-
vocab_file (`str`, *optional*):
|
|
54
|
-
Path to the vocabulary file.
|
|
55
|
-
merges_file (`str`, *optional*):
|
|
56
|
-
Path to the merges file.
|
|
57
53
|
"""
|
|
58
54
|
|
|
59
55
|
vocab_files_names = VOCAB_FILES_NAMES
|
|
60
56
|
model_input_names = ["input_ids", "attention_mask"]
|
|
61
|
-
|
|
57
|
+
model = BPE
|
|
62
58
|
|
|
63
59
|
def __init__(
|
|
64
60
|
self,
|
|
61
|
+
vocab: Optional[Union[str, dict[str, int]]] = None,
|
|
62
|
+
merges: Optional[Union[str, list[str]]] = None,
|
|
65
63
|
unk_token: str = "<|endoftext|>",
|
|
66
64
|
bos_token: str = "<|startoftext|>",
|
|
67
65
|
eos_token: str = "<|endoftext|>",
|
|
68
66
|
pad_token: str = "<|endoftext|>",
|
|
69
|
-
vocab: Optional[dict] = None,
|
|
70
|
-
merges: Optional[list] = None,
|
|
71
|
-
vocab_file: Optional[str] = None,
|
|
72
|
-
merges_file: Optional[str] = None,
|
|
73
67
|
**kwargs,
|
|
74
68
|
):
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
_vocab = {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
|
|
80
|
-
else:
|
|
81
|
-
_vocab = {
|
|
69
|
+
_vocab = (
|
|
70
|
+
vocab
|
|
71
|
+
if vocab is not None
|
|
72
|
+
else {
|
|
82
73
|
str(bos_token): 0,
|
|
83
74
|
str(eos_token): 1,
|
|
84
75
|
str(pad_token): 2,
|
|
85
76
|
}
|
|
77
|
+
)
|
|
86
78
|
|
|
87
|
-
|
|
88
|
-
_merges = [tuple(merge) if isinstance(merge, list) else merge for merge in merges]
|
|
89
|
-
else:
|
|
90
|
-
_merges = []
|
|
79
|
+
self._merges = merges or []
|
|
91
80
|
|
|
92
81
|
self._tokenizer = Tokenizer(
|
|
93
82
|
BPE(
|
|
94
83
|
vocab=_vocab,
|
|
95
|
-
merges=_merges,
|
|
84
|
+
merges=self._merges,
|
|
96
85
|
dropout=None,
|
|
97
86
|
continuing_subword_prefix="",
|
|
98
87
|
end_of_word_suffix="</w>",
|
|
@@ -120,20 +109,7 @@ class CLIPTokenizer(TokenizersBackend):
|
|
|
120
109
|
|
|
121
110
|
self._tokenizer.decoder = decoders.ByteLevel()
|
|
122
111
|
|
|
123
|
-
bos_token_id = _vocab.get(str(bos_token), 0)
|
|
124
|
-
eos_token_id = _vocab.get(str(eos_token), 1)
|
|
125
|
-
|
|
126
|
-
self._tokenizer.post_processor = processors.RobertaProcessing(
|
|
127
|
-
sep=(str(eos_token), eos_token_id),
|
|
128
|
-
cls=(str(bos_token), bos_token_id),
|
|
129
|
-
add_prefix_space=False,
|
|
130
|
-
trim_offsets=False,
|
|
131
|
-
)
|
|
132
|
-
|
|
133
|
-
tokenizer_object = self._tokenizer
|
|
134
|
-
|
|
135
112
|
super().__init__(
|
|
136
|
-
tokenizer_object=tokenizer_object,
|
|
137
113
|
unk_token=unk_token,
|
|
138
114
|
bos_token=bos_token,
|
|
139
115
|
eos_token=eos_token,
|
|
@@ -141,14 +117,16 @@ class CLIPTokenizer(TokenizersBackend):
|
|
|
141
117
|
**kwargs,
|
|
142
118
|
)
|
|
143
119
|
|
|
144
|
-
|
|
145
|
-
self.
|
|
120
|
+
self._tokenizer.post_processor = processors.RobertaProcessing(
|
|
121
|
+
sep=(str(eos_token), self.eos_token_id),
|
|
122
|
+
cls=(str(bos_token), self.bos_token_id),
|
|
123
|
+
add_prefix_space=False,
|
|
124
|
+
trim_offsets=False,
|
|
125
|
+
)
|
|
146
126
|
|
|
147
|
-
|
|
148
|
-
super()._post_init()
|
|
127
|
+
# Very ugly hack to enable padding to have a correct decoding see https://github.com/huggingface/tokenizers/issues/872
|
|
149
128
|
self._wrap_decode_method_backend_tokenizer()
|
|
150
129
|
|
|
151
|
-
# Very ugly hack to enable padding to have a correct decoding see https://github.com/huggingface/tokenizers/issues/872
|
|
152
130
|
def _wrap_decode_method_backend_tokenizer(self):
|
|
153
131
|
orig_decode_method = self.backend_tokenizer.decode
|
|
154
132
|
|
|
@@ -676,6 +676,7 @@ class CLIPSegTextModel(CLIPSegPreTrainedModel):
|
|
|
676
676
|
output_attentions: Optional[bool] = None,
|
|
677
677
|
output_hidden_states: Optional[bool] = None,
|
|
678
678
|
return_dict: Optional[bool] = None,
|
|
679
|
+
**kwargs,
|
|
679
680
|
) -> Union[tuple, BaseModelOutputWithPooling]:
|
|
680
681
|
r"""
|
|
681
682
|
Examples:
|
|
@@ -776,6 +777,7 @@ class CLIPSegVisionModel(CLIPSegPreTrainedModel):
|
|
|
776
777
|
output_hidden_states: Optional[bool] = None,
|
|
777
778
|
interpolate_pos_encoding: Optional[bool] = True,
|
|
778
779
|
return_dict: Optional[bool] = None,
|
|
780
|
+
**kwargs,
|
|
779
781
|
) -> Union[tuple, BaseModelOutputWithPooling]:
|
|
780
782
|
r"""
|
|
781
783
|
Examples:
|
|
@@ -933,6 +935,7 @@ class CLIPSegModel(CLIPSegPreTrainedModel):
|
|
|
933
935
|
output_hidden_states: Optional[bool] = None,
|
|
934
936
|
interpolate_pos_encoding: bool = True,
|
|
935
937
|
return_dict: Optional[bool] = None,
|
|
938
|
+
**kwargs,
|
|
936
939
|
) -> Union[tuple, CLIPSegOutput]:
|
|
937
940
|
r"""
|
|
938
941
|
return_loss (`bool`, *optional*):
|
|
@@ -1125,6 +1128,7 @@ class CLIPSegDecoder(CLIPSegPreTrainedModel):
|
|
|
1125
1128
|
output_attentions: Optional[bool] = None,
|
|
1126
1129
|
output_hidden_states: Optional[bool] = None,
|
|
1127
1130
|
return_dict: Optional[bool] = True,
|
|
1131
|
+
**kwargs,
|
|
1128
1132
|
):
|
|
1129
1133
|
all_hidden_states = () if output_hidden_states else None
|
|
1130
1134
|
all_attentions = () if output_attentions else None
|
|
@@ -1239,6 +1243,7 @@ class CLIPSegForImageSegmentation(CLIPSegPreTrainedModel):
|
|
|
1239
1243
|
output_hidden_states: Optional[bool] = None,
|
|
1240
1244
|
interpolate_pos_encoding: bool = True,
|
|
1241
1245
|
return_dict: Optional[bool] = None,
|
|
1246
|
+
**kwargs,
|
|
1242
1247
|
) -> Union[tuple, CLIPSegOutput]:
|
|
1243
1248
|
r"""
|
|
1244
1249
|
conditional_pixel_values (`torch.FloatTensor`, *optional*):
|
|
@@ -861,6 +861,7 @@ class ClvpEncoder(ClvpPreTrainedModel):
|
|
|
861
861
|
output_attentions: Optional[bool] = None,
|
|
862
862
|
output_hidden_states: Optional[bool] = None,
|
|
863
863
|
return_dict: Optional[bool] = None,
|
|
864
|
+
**kwargs,
|
|
864
865
|
) -> Union[tuple, BaseModelOutput]:
|
|
865
866
|
r"""
|
|
866
867
|
Args:
|
|
@@ -1020,6 +1021,7 @@ class ClvpDecoder(ClvpPreTrainedModel):
|
|
|
1020
1021
|
output_hidden_states: Optional[bool] = None,
|
|
1021
1022
|
return_dict: Optional[bool] = None,
|
|
1022
1023
|
cache_position: Optional[torch.Tensor] = None,
|
|
1024
|
+
**kwargs,
|
|
1023
1025
|
) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
|
1024
1026
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
1025
1027
|
output_hidden_states = (
|
|
@@ -1170,6 +1172,7 @@ class ClvpModel(ClvpPreTrainedModel):
|
|
|
1170
1172
|
output_hidden_states: Optional[bool] = None,
|
|
1171
1173
|
return_dict: Optional[bool] = None,
|
|
1172
1174
|
cache_position: Optional[torch.Tensor] = None,
|
|
1175
|
+
**kwargs,
|
|
1173
1176
|
) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
|
1174
1177
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
1175
1178
|
output_hidden_states = (
|
|
@@ -1339,6 +1342,7 @@ class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
|
|
|
1339
1342
|
output_hidden_states: Optional[bool] = None,
|
|
1340
1343
|
return_dict: Optional[bool] = None,
|
|
1341
1344
|
cache_position: Optional[torch.Tensor] = None,
|
|
1345
|
+
**kwargs,
|
|
1342
1346
|
) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
|
|
1343
1347
|
r"""
|
|
1344
1348
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
@@ -1635,6 +1639,7 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel, GenerationMixin):
|
|
|
1635
1639
|
output_attentions: Optional[bool] = False,
|
|
1636
1640
|
return_dict: Optional[bool] = None,
|
|
1637
1641
|
cache_position: Optional[torch.Tensor] = None,
|
|
1642
|
+
**kwargs,
|
|
1638
1643
|
) -> Union[tuple, ClvpOutput]:
|
|
1639
1644
|
r"""
|
|
1640
1645
|
conditioning_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
|
|
@@ -15,9 +15,7 @@
|
|
|
15
15
|
"""Tokenization class for CLVP."""
|
|
16
16
|
|
|
17
17
|
import json
|
|
18
|
-
import os
|
|
19
18
|
from functools import lru_cache
|
|
20
|
-
from typing import Optional
|
|
21
19
|
|
|
22
20
|
import regex as re
|
|
23
21
|
|
|
@@ -123,10 +121,6 @@ class ClvpTokenizer(PreTrainedTokenizer):
|
|
|
123
121
|
add_prefix_space (`bool`, *optional*, defaults to `False`):
|
|
124
122
|
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
|
125
123
|
other word. (CLVP tokenizer detect beginning of words by the preceding space).
|
|
126
|
-
add_bos_token (`bool`, *optional*, defaults to `False`):
|
|
127
|
-
Whether to add `bos_token` in front of the sequence when add_special_tokens=True.
|
|
128
|
-
add_eos_token (`bool`, *optional*, defaults to `False`):
|
|
129
|
-
Whether to add `eos_token` in end of the sequence when add_special_tokens=True.
|
|
130
124
|
"""
|
|
131
125
|
|
|
132
126
|
vocab_files_names = VOCAB_FILES_NAMES
|
|
@@ -145,8 +139,6 @@ class ClvpTokenizer(PreTrainedTokenizer):
|
|
|
145
139
|
eos_token="[STOP]",
|
|
146
140
|
pad_token="[STOP]",
|
|
147
141
|
add_prefix_space=False,
|
|
148
|
-
add_bos_token=False,
|
|
149
|
-
add_eos_token=False,
|
|
150
142
|
**kwargs,
|
|
151
143
|
):
|
|
152
144
|
bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token
|
|
@@ -154,20 +146,7 @@ class ClvpTokenizer(PreTrainedTokenizer):
|
|
|
154
146
|
unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
|
|
155
147
|
pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
|
|
156
148
|
|
|
157
|
-
self.add_bos_token = add_bos_token
|
|
158
|
-
self.add_eos_token = add_eos_token
|
|
159
149
|
self._normalizer = None
|
|
160
|
-
|
|
161
|
-
# Set special_tokens_pattern based on add_bos_token and add_eos_token flags
|
|
162
|
-
if add_bos_token and add_eos_token:
|
|
163
|
-
kwargs["special_tokens_pattern"] = "bos_eos"
|
|
164
|
-
elif add_bos_token:
|
|
165
|
-
kwargs["special_tokens_pattern"] = "bos"
|
|
166
|
-
elif add_eos_token:
|
|
167
|
-
kwargs["special_tokens_pattern"] = "eos"
|
|
168
|
-
else:
|
|
169
|
-
kwargs["special_tokens_pattern"] = "none"
|
|
170
|
-
|
|
171
150
|
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
|
172
151
|
self.encoder = json.load(vocab_handle)
|
|
173
152
|
self.decoder = {v: k for k, v in self.encoder.items()}
|
|
@@ -191,8 +170,7 @@ class ClvpTokenizer(PreTrainedTokenizer):
|
|
|
191
170
|
eos_token=eos_token,
|
|
192
171
|
pad_token=pad_token,
|
|
193
172
|
add_prefix_space=add_prefix_space,
|
|
194
|
-
|
|
195
|
-
add_eos_token=add_eos_token,
|
|
173
|
+
special_tokens_pattern="none",
|
|
196
174
|
**kwargs,
|
|
197
175
|
)
|
|
198
176
|
|
|
@@ -251,17 +229,6 @@ class ClvpTokenizer(PreTrainedTokenizer):
|
|
|
251
229
|
self.cache[token] = word
|
|
252
230
|
return word
|
|
253
231
|
|
|
254
|
-
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
|
255
|
-
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
|
256
|
-
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
|
257
|
-
|
|
258
|
-
output = bos_token_id + token_ids_0 + eos_token_id
|
|
259
|
-
|
|
260
|
-
if token_ids_1 is not None:
|
|
261
|
-
output = output + bos_token_id + token_ids_1 + eos_token_id
|
|
262
|
-
|
|
263
|
-
return output
|
|
264
|
-
|
|
265
232
|
def _tokenize(self, text):
|
|
266
233
|
"""Tokenize a string."""
|
|
267
234
|
bpe_tokens = []
|
|
@@ -303,34 +270,5 @@ class ClvpTokenizer(PreTrainedTokenizer):
|
|
|
303
270
|
text = text.replace(self.unk_token, "").replace(" ", " ").replace(" ", " ")
|
|
304
271
|
return text
|
|
305
272
|
|
|
306
|
-
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
|
|
307
|
-
if not os.path.isdir(save_directory):
|
|
308
|
-
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
|
309
|
-
return
|
|
310
|
-
vocab_file = os.path.join(
|
|
311
|
-
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
|
312
|
-
)
|
|
313
|
-
merge_file = os.path.join(
|
|
314
|
-
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
|
|
315
|
-
)
|
|
316
|
-
|
|
317
|
-
with open(vocab_file, "w", encoding="utf-8") as f:
|
|
318
|
-
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
|
319
|
-
|
|
320
|
-
index = 0
|
|
321
|
-
with open(merge_file, "w", encoding="utf-8") as writer:
|
|
322
|
-
writer.write("#version: 0.2\n")
|
|
323
|
-
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
|
324
|
-
if index != token_index:
|
|
325
|
-
logger.warning(
|
|
326
|
-
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
|
|
327
|
-
" Please check that the tokenizer is not corrupted!"
|
|
328
|
-
)
|
|
329
|
-
index = token_index
|
|
330
|
-
writer.write(" ".join(bpe_tokens) + "\n")
|
|
331
|
-
index += 1
|
|
332
|
-
|
|
333
|
-
return vocab_file, merge_file
|
|
334
|
-
|
|
335
273
|
|
|
336
274
|
__all__ = ["ClvpTokenizer"]
|
|
@@ -14,10 +14,11 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
from
|
|
17
|
+
from typing import Optional, Union
|
|
18
|
+
|
|
19
|
+
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
|
|
18
20
|
from tokenizers.models import BPE
|
|
19
21
|
|
|
20
|
-
from ...tokenization_utils_base import _get_prepend_scheme, generate_merges
|
|
21
22
|
from ...tokenization_utils_tokenizers import TokenizersBackend
|
|
22
23
|
from ...utils import logging
|
|
23
24
|
|
|
@@ -97,9 +98,9 @@ class CodeLlamaTokenizer(TokenizersBackend):
|
|
|
97
98
|
add_prefix_space (`bool`, *optional*):
|
|
98
99
|
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
|
99
100
|
other word.
|
|
100
|
-
vocab (`dict`, *optional*):
|
|
101
|
+
vocab (`str`, `dict` or `list`, *optional*):
|
|
101
102
|
Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
|
|
102
|
-
merges (`list`, *optional*):
|
|
103
|
+
merges (`str` or `list`, *optional*):
|
|
103
104
|
Custom merges list. If not provided, merges are loaded from merges_file.
|
|
104
105
|
vocab_file (`str`, *optional*):
|
|
105
106
|
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
|
|
@@ -109,9 +110,12 @@ class CodeLlamaTokenizer(TokenizersBackend):
|
|
|
109
110
|
vocab_files_names = VOCAB_FILES_NAMES
|
|
110
111
|
padding_side = "left"
|
|
111
112
|
model_input_names = ["input_ids", "attention_mask"]
|
|
113
|
+
model = BPE
|
|
112
114
|
|
|
113
115
|
def __init__(
|
|
114
116
|
self,
|
|
117
|
+
vocab: Optional[Union[str, dict[str, int]]] = None,
|
|
118
|
+
merges: Optional[Union[str, list[str]]] = None,
|
|
115
119
|
clean_up_tokenization_spaces=False,
|
|
116
120
|
unk_token="<unk>",
|
|
117
121
|
bos_token="<s>",
|
|
@@ -122,37 +126,28 @@ class CodeLlamaTokenizer(TokenizersBackend):
|
|
|
122
126
|
eot_token="▁<EOT>",
|
|
123
127
|
fill_token="<FILL_ME>",
|
|
124
128
|
additional_special_tokens=None,
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
add_prefix_space=None,
|
|
129
|
-
vocab=None,
|
|
130
|
-
merges=None,
|
|
131
|
-
vocab_file=None,
|
|
129
|
+
use_default_system_prompt: bool = False,
|
|
130
|
+
add_prefix_space: Optional[bool] = True,
|
|
131
|
+
add_bos_token: bool = True,
|
|
132
132
|
**kwargs,
|
|
133
133
|
):
|
|
134
134
|
self.add_prefix_space = add_prefix_space if add_prefix_space is not None else True
|
|
135
135
|
self.use_default_system_prompt = use_default_system_prompt
|
|
136
|
-
|
|
137
136
|
additional_special_tokens = additional_special_tokens or []
|
|
138
137
|
for token in [prefix_token, middle_token, suffix_token, eot_token, fill_token]:
|
|
139
138
|
additional_special_tokens += [token] if token is not None else []
|
|
140
139
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
else:
|
|
146
|
-
self._vocab = {
|
|
140
|
+
self._vocab = (
|
|
141
|
+
vocab
|
|
142
|
+
if vocab is not None
|
|
143
|
+
else {
|
|
147
144
|
str(unk_token): 0,
|
|
148
145
|
str(bos_token): 1,
|
|
149
146
|
str(eos_token): 2,
|
|
150
147
|
}
|
|
148
|
+
)
|
|
151
149
|
|
|
152
|
-
|
|
153
|
-
t: i for t, i in self._vocab.items() if t not in {str(eos_token), str(bos_token), str(unk_token)}
|
|
154
|
-
}
|
|
155
|
-
self._merges = merges if merges is not None else generate_merges(filtered_vocab)
|
|
150
|
+
self._merges = merges or []
|
|
156
151
|
self._tokenizer = Tokenizer(
|
|
157
152
|
BPE(
|
|
158
153
|
vocab=self._vocab,
|
|
@@ -163,8 +158,9 @@ class CodeLlamaTokenizer(TokenizersBackend):
|
|
|
163
158
|
unk_token=str(unk_token),
|
|
164
159
|
)
|
|
165
160
|
)
|
|
161
|
+
prepend_scheme = "first" if self.add_prefix_space else "none"
|
|
166
162
|
self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
|
|
167
|
-
replacement="▁", prepend_scheme=
|
|
163
|
+
replacement="▁", prepend_scheme=prepend_scheme, split=False
|
|
168
164
|
)
|
|
169
165
|
|
|
170
166
|
self._tokenizer.decoder = decoders.Sequence(
|
|
@@ -172,13 +168,10 @@ class CodeLlamaTokenizer(TokenizersBackend):
|
|
|
172
168
|
)
|
|
173
169
|
|
|
174
170
|
super().__init__(
|
|
175
|
-
tokenizer_object=self._tokenizer,
|
|
176
171
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
|
177
172
|
unk_token=unk_token,
|
|
178
173
|
bos_token=bos_token,
|
|
179
174
|
eos_token=eos_token,
|
|
180
|
-
add_bos_token=add_bos_token,
|
|
181
|
-
add_eos_token=add_eos_token,
|
|
182
175
|
use_default_system_prompt=use_default_system_prompt,
|
|
183
176
|
add_prefix_space=add_prefix_space,
|
|
184
177
|
prefix_token=prefix_token,
|
|
@@ -186,32 +179,16 @@ class CodeLlamaTokenizer(TokenizersBackend):
|
|
|
186
179
|
suffix_token=suffix_token,
|
|
187
180
|
eot_token=eot_token,
|
|
188
181
|
fill_token=fill_token,
|
|
182
|
+
add_bos_token=add_bos_token,
|
|
189
183
|
additional_special_tokens=additional_special_tokens,
|
|
190
184
|
**kwargs,
|
|
191
185
|
)
|
|
192
|
-
|
|
193
|
-
self._add_bos_token = add_bos_token
|
|
194
|
-
self._add_eos_token = add_eos_token
|
|
195
|
-
self.vocab_file = vocab_file
|
|
196
|
-
|
|
197
186
|
self._prefix_token = prefix_token
|
|
198
187
|
self._middle_token = middle_token
|
|
199
188
|
self._suffix_token = suffix_token
|
|
200
189
|
self._eot_token = eot_token
|
|
201
190
|
self.fill_token = fill_token
|
|
202
191
|
|
|
203
|
-
self._post_init()
|
|
204
|
-
|
|
205
|
-
def _post_init(self):
|
|
206
|
-
self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="first", split=False)
|
|
207
|
-
self._tokenizer.normalizer = None
|
|
208
|
-
|
|
209
|
-
# This matches LlamaTokenizer's behavior and is needed when loading from vocab/merges
|
|
210
|
-
self.add_tokens([AddedToken(token, special=True) for token in self.all_special_tokens])
|
|
211
|
-
|
|
212
|
-
self.update_post_processor()
|
|
213
|
-
super()._post_init()
|
|
214
|
-
|
|
215
192
|
@property
|
|
216
193
|
def prefix_token(self):
|
|
217
194
|
return self._prefix_token
|
|
@@ -67,6 +67,10 @@ class CodeGenTokenizer(TokenizersBackend):
|
|
|
67
67
|
refer to this superclass for more information regarding those methods.
|
|
68
68
|
|
|
69
69
|
Args:
|
|
70
|
+
vocab (`str` or `dict[str, int]`, *optional*):
|
|
71
|
+
Custom vocabulary dictionary. If not provided, vocabulary is loaded from `vocab_file`.
|
|
72
|
+
merges (`str` or `list[str]`, *optional*):
|
|
73
|
+
Custom merges list. If not provided, merges are loaded from `merges_file`.
|
|
70
74
|
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
|
71
75
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
|
72
76
|
token instead.
|
|
@@ -79,31 +83,24 @@ class CodeGenTokenizer(TokenizersBackend):
|
|
|
79
83
|
add_prefix_space (`bool`, *optional*, defaults to `False`):
|
|
80
84
|
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
|
81
85
|
other word. (CodeGen tokenizer detect beginning of words by the preceding space).
|
|
82
|
-
add_bos_token (`bool`, *optional*, defaults to `False`):
|
|
83
|
-
Whether or not to add an initial beginning of sentence token to the input.
|
|
84
86
|
return_token_type_ids (`bool`, *optional*, defaults to `False`):
|
|
85
87
|
Whether to return token type IDs.
|
|
86
|
-
vocab (`dict`, *optional*):
|
|
87
|
-
Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
|
|
88
|
-
merges (`list`, *optional*):
|
|
89
|
-
Custom merges list. If not provided, merges are loaded from merges_file.
|
|
90
88
|
"""
|
|
91
89
|
|
|
92
90
|
vocab_files_names = VOCAB_FILES_NAMES
|
|
93
91
|
model_input_names = ["input_ids", "attention_mask"]
|
|
94
|
-
|
|
92
|
+
model = BPE
|
|
95
93
|
|
|
96
94
|
def __init__(
|
|
97
95
|
self,
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
96
|
+
vocab: Optional[Union[str, dict[str, int]]] = None,
|
|
97
|
+
merges: Optional[Union[str, list[str]]] = None,
|
|
98
|
+
unk_token: str = "<|endoftext|>",
|
|
99
|
+
bos_token: str = "<|endoftext|>",
|
|
100
|
+
eos_token: str = "<|endoftext|>",
|
|
101
101
|
pad_token=None,
|
|
102
|
-
add_prefix_space=False,
|
|
103
|
-
|
|
104
|
-
return_token_type_ids=False,
|
|
105
|
-
vocab: Optional[dict] = None,
|
|
106
|
-
merges: Optional[list] = None,
|
|
102
|
+
add_prefix_space: bool = False,
|
|
103
|
+
return_token_type_ids: bool = False,
|
|
107
104
|
**kwargs,
|
|
108
105
|
):
|
|
109
106
|
self.return_token_type_ids = return_token_type_ids
|
|
@@ -112,17 +109,8 @@ class CodeGenTokenizer(TokenizersBackend):
|
|
|
112
109
|
|
|
113
110
|
self.add_prefix_space = add_prefix_space
|
|
114
111
|
|
|
115
|
-
if vocab is not None
|
|
116
|
-
|
|
117
|
-
{token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
|
|
118
|
-
)
|
|
119
|
-
else:
|
|
120
|
-
self._vocab = {}
|
|
121
|
-
|
|
122
|
-
if merges is not None:
|
|
123
|
-
self._merges = merges
|
|
124
|
-
else:
|
|
125
|
-
self._merges = []
|
|
112
|
+
self._vocab = vocab if vocab is not None else {}
|
|
113
|
+
self._merges = merges or []
|
|
126
114
|
|
|
127
115
|
self._tokenizer = Tokenizer(
|
|
128
116
|
BPE(
|
|
@@ -141,33 +129,16 @@ class CodeGenTokenizer(TokenizersBackend):
|
|
|
141
129
|
add_prefix_space=True, use_regex=True, trim_offsets=False
|
|
142
130
|
)
|
|
143
131
|
|
|
144
|
-
tokenizer_object = self._tokenizer
|
|
145
|
-
|
|
146
|
-
# Set these before calling super().__init__() so the base class _post_init() can use them
|
|
147
|
-
self._add_bos_token = add_bos_token
|
|
148
|
-
self._add_eos_token = False
|
|
149
|
-
|
|
150
132
|
super().__init__(
|
|
151
|
-
tokenizer_object=tokenizer_object,
|
|
152
133
|
unk_token=unk_token,
|
|
153
134
|
bos_token=bos_token,
|
|
154
135
|
eos_token=eos_token,
|
|
155
136
|
pad_token=pad_token,
|
|
156
137
|
add_prefix_space=add_prefix_space,
|
|
157
|
-
add_bos_token=add_bos_token,
|
|
158
138
|
return_token_type_ids=return_token_type_ids,
|
|
159
139
|
**kwargs,
|
|
160
140
|
)
|
|
161
141
|
|
|
162
|
-
self._post_init()
|
|
163
|
-
|
|
164
|
-
def _post_init(self):
|
|
165
|
-
self._tokenizer.post_processor = processors.ByteLevel(
|
|
166
|
-
add_prefix_space=True, use_regex=True, trim_offsets=False
|
|
167
|
-
)
|
|
168
|
-
# Ensure base class post-init runs to register special/extra tokens, etc.
|
|
169
|
-
super()._post_init()
|
|
170
|
-
|
|
171
142
|
def decode(
|
|
172
143
|
self,
|
|
173
144
|
token_ids: Union[int, list[int], np.ndarray, "torch.Tensor"],
|
|
@@ -36,6 +36,7 @@ from torch import nn
|
|
|
36
36
|
from ...activations import ACT2FN
|
|
37
37
|
from ...cache_utils import Cache, DynamicCache
|
|
38
38
|
from ...generation import GenerationMixin
|
|
39
|
+
from ...integrations import use_kernelized_func
|
|
39
40
|
from ...masking_utils import create_causal_mask
|
|
40
41
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
|
41
42
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
@@ -44,7 +45,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
|
44
45
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
45
46
|
from ...processing_utils import Unpack
|
|
46
47
|
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
|
47
|
-
from ...utils.generic import check_model_inputs
|
|
48
|
+
from ...utils.generic import check_model_inputs, maybe_autocast
|
|
48
49
|
from .configuration_cohere import CohereConfig
|
|
49
50
|
|
|
50
51
|
|
|
@@ -121,7 +122,7 @@ class CohereRotaryEmbedding(nn.Module):
|
|
|
121
122
|
position_ids_expanded = position_ids[:, None, :].float()
|
|
122
123
|
|
|
123
124
|
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
|
124
|
-
with
|
|
125
|
+
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
|
125
126
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
|
126
127
|
emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
|
|
127
128
|
cos = emb.cos() * self.attention_scaling
|
|
@@ -222,6 +223,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
|
222
223
|
return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
|
|
223
224
|
|
|
224
225
|
|
|
226
|
+
@use_kernelized_func(apply_rotary_pos_emb)
|
|
225
227
|
class CohereAttention(nn.Module):
|
|
226
228
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
|
227
229
|
|
|
@@ -247,7 +249,6 @@ class CohereAttention(nn.Module):
|
|
|
247
249
|
self.o_proj = nn.Linear(
|
|
248
250
|
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
|
249
251
|
)
|
|
250
|
-
self.rotary_fn = apply_rotary_pos_emb
|
|
251
252
|
self.use_qk_norm = config.use_qk_norm
|
|
252
253
|
if self.use_qk_norm:
|
|
253
254
|
# When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
|
|
@@ -36,6 +36,7 @@ from ...modeling_rope_utils import dynamic_rope_update
|
|
|
36
36
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
|
37
37
|
from ...processing_utils import Unpack
|
|
38
38
|
from ...utils import TransformersKwargs, logging
|
|
39
|
+
from ...utils.generic import maybe_autocast
|
|
39
40
|
from ..llama.modeling_llama import (
|
|
40
41
|
LlamaAttention,
|
|
41
42
|
LlamaForCausalLM,
|
|
@@ -75,7 +76,7 @@ class CohereRotaryEmbedding(LlamaRotaryEmbedding):
|
|
|
75
76
|
position_ids_expanded = position_ids[:, None, :].float()
|
|
76
77
|
|
|
77
78
|
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
|
78
|
-
with
|
|
79
|
+
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
|
79
80
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
|
80
81
|
emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
|
|
81
82
|
cos = emb.cos() * self.attention_scaling
|