transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl
This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that public registry.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,7 @@ from collections.abc import Mapping
 from typing import Optional, Union

 import numpy as np
-from tokenizers import Tokenizer, decoders, pre_tokenizers
+from tokenizers import Tokenizer, decoders, pre_tokenizers
 from tokenizers.models import BPE

 from ...tokenization_python import PreTrainedTokenizer

@@ -167,6 +167,10 @@ class LukeTokenizer(TokenizersBackend):
             Path to the vocabulary file.
         merges_file (`str`):
             Path to the merges file.
+        vocab (`str` or `dict[str, int]`, *optional*):
+            Custom vocabulary dictionary. If not provided, the vocabulary is loaded from `vocab_file`.
+        merges (`str` or `list[str]`, *optional*):
+            Custom merges list. If not provided, merges are loaded from `merges_file`.
         entity_vocab_file (`str`):
             Path to the entity vocabulary file.
         task (`str`, *optional*):

@@ -228,10 +232,13 @@ class LukeTokenizer(TokenizersBackend):

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = BPE

     def __init__(
         self,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
+        entity_vocab: Optional[Union[str, dict, list]] = None,
         errors="replace",
         bos_token="<s>",
         eos_token="</s>",

@@ -250,37 +257,17 @@ class LukeTokenizer(TokenizersBackend):
         entity_pad_token="[PAD]",
         entity_mask_token="[MASK]",
         entity_mask2_token="[MASK2]",
-        vocab: Optional[dict] = None,
-        merges: Optional[list] = None,
-        entity_vocab: Optional[dict] = None,
         **kwargs,
     ):
         self.add_prefix_space = add_prefix_space

         # Handle entity vocab file for backward compatibility
         entity_vocab_file = kwargs.pop("entity_vocab_file", None)
-
-        # Check if vocab/merges/entity_vocab are in kwargs
-        if vocab is None and "vocab" in kwargs:
-            vocab = kwargs.pop("vocab")
-        if merges is None and "merges" in kwargs:
-            merges = kwargs.pop("merges")
         if entity_vocab is None and "entity_vocab" in kwargs:
             entity_vocab = kwargs.pop("entity_vocab")

-
-
-            self._vocab = (
-                {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
-            )
-        else:
-            self._vocab = {}
-
-        if merges is not None:
-            self._merges = [tuple(merge) if isinstance(merge, list) else merge for merge in merges]
-        else:
-            self._merges = []
-
+        self._vocab = vocab or {}
+        self._merges = merges or []
         self._tokenizer = Tokenizer(
             BPE(
                 vocab=self._vocab,

@@ -365,8 +352,6 @@ class LukeTokenizer(TokenizersBackend):

         kwargs["extra_special_tokens"] = extra_tokens

-        tokenizer_object = self._tokenizer
-
         # Configure default special token behaviors to match LUKE formatting
         token_type_ids_pattern = kwargs.setdefault("token_type_ids_pattern", "all_zeros")
         special_tokens_pattern = kwargs.setdefault("special_tokens_pattern", "cls_double_sep")

@@ -379,7 +364,6 @@ class LukeTokenizer(TokenizersBackend):
         kwargs.setdefault("clean_up_tokenization_spaces", True)

         super().__init__(
-            tokenizer_object=tokenizer_object,
             errors=errors,
             bos_token=bos_token,
             eos_token=eos_token,

@@ -401,17 +385,6 @@ class LukeTokenizer(TokenizersBackend):
             entity_vocab=entity_vocab if entity_vocab_file is None else None,  # Only store if it was passed as data
             **kwargs,
         )
-        self._post_init()
-
-    def _post_init(self):
-        self._tokenizer.post_processor = processors.TemplateProcessing(
-            single=f"{self.cls_token}:0 $A:0 {self.sep_token}:0",
-            pair=f"{self.cls_token}:0 $A:0 {self.sep_token}:0 {self.sep_token}:0 $B:1 {self.sep_token}:1",
-            special_tokens=[
-                (self.cls_token, self.cls_token_id),
-                (self.sep_token, self.sep_token_id),
-            ],
-        )

     def build_inputs_with_special_tokens(
         self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
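Note: the LukeTokenizer hunks above drop the old vocab/merges normalization and pass `vocab`/`merges` straight into `BPE(...)` from the `tokenizers` library. A rough standalone sketch of that underlying construction, using a toy vocabulary and merge list rather than the real LUKE files:

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE

# Toy inputs; a real tokenizer loads these from its vocab/merges files.
vocab = {"<unk>": 0, "l": 1, "o": 2, "lo": 3}
merges = [("l", "o")]

tok = Tokenizer(BPE(vocab=vocab, merges=merges, unk_token="<unk>"))
print(tok.encode("lo").tokens)  # ['lo']
```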
@@ -711,6 +711,7 @@ class LxmertModel(LxmertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[LxmertModelOutput, tuple[torch.FloatTensor]]:
         r"""
         visual_feats (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`):

@@ -1244,6 +1245,7 @@ class LxmertForQuestionAnswering(LxmertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[LxmertForQuestionAnsweringOutput, tuple[torch.FloatTensor]]:
         r"""
         visual_feats (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`):

@@ -561,6 +561,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ):
         r"""
         Args:

@@ -713,6 +714,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ):
         r"""
         Args:

@@ -941,6 +943,7 @@ class M2M100Model(M2M100PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):

@@ -1046,6 +1049,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -34,7 +34,7 @@ from ...utils import (
     auto_docstring,
     logging,
 )
-from ...utils.import_utils import
+from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling
 from .configuration_mamba import MambaConfig



@@ -45,12 +45,6 @@ if is_mambapy_available():
 else:
     pscan = None

-if is_mamba_ssm_available():
-    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-else:
-    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
-

 class MambaCache:
     """

@@ -204,15 +198,24 @@ class MambaMixer(nn.Module):
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
         self.use_bias = config.use_bias

-
-
-    def warn_slow_implementation(self):
+        global causal_conv1d, causal_conv1d_update, causal_conv1d_fn
         causal_conv1d = lazy_load_kernel("causal-conv1d")
         causal_conv1d_update, causal_conv1d_fn = (
             (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
             if causal_conv1d is not None
             else (None, None)
         )
+        global mamba_ssm, selective_state_update, selective_scan_fn, mamba_inner_fn
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update, selective_scan_fn, mamba_inner_fn = (
+            (mamba_ssm.selective_state_update, mamba_ssm.selective_scan_fn, mamba_ssm.mamba_inner_fn)
+            if mamba_ssm is not None
+            else (None, None, None)
+        )
+
+        self.warn_slow_implementation()
+
+    def warn_slow_implementation(self):
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )

@@ -263,12 +266,6 @@ class MambaMixer(nn.Module):
         )

         else:
-            causal_conv1d = lazy_load_kernel("causal-conv1d")
-            causal_conv1d_update, causal_conv1d_fn = (
-                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
-                if causal_conv1d is not None
-                else (None, None)
-            )
             hidden_states, gate = projected_states.chunk(2, dim=1)

             if attention_mask is not None:

@@ -432,12 +429,6 @@ class MambaMixer(nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d = lazy_load_kernel("causal-conv1d")
-        causal_conv1d_update, causal_conv1d_fn = (
-            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
-            if causal_conv1d is not None
-            else (None, None)
-        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )

@@ -640,6 +631,7 @@ class MambaModel(MambaPreTrainedModel):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, MambaOutput]:
         r"""
         cache_params (`MambaCache`, *optional*):
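Note: the MambaMixer hunks above resolve the optional `causal-conv1d` and `mamba-ssm` kernels once in `__init__` (binding module-level globals via `lazy_load_kernel`) rather than re-resolving them on every forward call. A generic sketch of that optional-fast-kernel pattern using plain `importlib`; the exact behavior of `lazy_load_kernel` is an assumption here and not shown by the diff:

```python
import importlib


def _maybe_import(name: str):
    # Return the module if it is installed, otherwise None.
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# Resolve the optional fast kernel once; callers fall back to a slow path when it is missing.
causal_conv1d = _maybe_import("causal_conv1d")
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
use_fast_path = causal_conv1d_fn is not None
```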
@@ -504,6 +504,7 @@ class MarianEncoder(MarianPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
         r"""
         Args:

@@ -645,6 +646,7 @@ class MarianDecoder(MarianPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         Args:

@@ -925,6 +927,7 @@ class MarianModel(MarianPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Seq2SeqModelOutput:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):

@@ -1140,6 +1143,7 @@ class MarianMTModel(MarianPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Seq2SeqLMOutput:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):

@@ -1288,6 +1292,7 @@ class MarianForCausalLM(MarianPreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -562,6 +562,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):

@@ -669,6 +670,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
         r"""
         xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):

@@ -784,6 +786,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
         r"""
         xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):

@@ -886,6 +889,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
         r"""
         xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):

@@ -101,10 +101,10 @@ class MarkupLMTokenizer(TokenizersBackend):
     Users should refer to this superclass for more information regarding those methods.

     Args:
-
-
-
-
+        vocab (`str` or `dict[str, int]`, *optional*):
+            Custom vocabulary dictionary. If not provided, the vocabulary is loaded from `vocab_file`.
+        merges (`str` or `list[str]`, *optional*):
+            Custom merges list. If not provided, merges are loaded from `merges_file`.
         errors (`str`, *optional*, defaults to `"replace"`):
             Paradigm to follow when decoding bytes to UTF-8. See
             [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.

@@ -149,12 +149,14 @@ class MarkupLMTokenizer(TokenizersBackend):
     """

     vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
+    model = BPE

     def __init__(
         self,
         tags_dict,
-        vocab: Optional[Union[dict, list]] = None,
-        merges: Optional[list] = None,
+        vocab: Optional[Union[str, dict[str, int], list[tuple[str, float]]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
         errors="replace",
         bos_token="<s>",
         eos_token="</s>",

@@ -172,57 +174,28 @@ class MarkupLMTokenizer(TokenizersBackend):
         trim_offsets=False,
         **kwargs,
     ):
-        if kwargs.get("from_slow"):
-            logger.warning(
-                "MarkupLMTokenizer no longer supports initialization from a slow tokenizer. Ignoring `from_slow=True`."
-            )
-            kwargs["from_slow"] = False
         bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
         eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
         sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
         cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
         unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
         pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
-
         # Mask token behave like a normal word, i.e. include the space before it
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

-
-
-
-        if isinstance(processed_vocab, list):
-            processed_vocab = {
-                token: index for index, (token, _score) in enumerate(processed_vocab) if isinstance(token, str)
-            }
-        elif isinstance(processed_vocab, dict):
-            processed_vocab = {str(token): int(index) for token, index in processed_vocab.items()}
-
-        if processed_vocab is None:
-            processed_vocab = {
+        if vocab is None:
+            vocab = {
                 str(pad_token): 0,
                 str(unk_token): 1,
                 str(cls_token): 2,
                 str(sep_token): 3,
                 str(mask_token): 4,
             }
-
-        normalized_merges = []
-        if processed_merges is not None:
-            for merge in processed_merges:
-                if isinstance(merge, tuple) and len(merge) == 2:
-                    normalized_merges.append((merge[0], merge[1]))
-                elif isinstance(merge, list) and len(merge) == 2:
-                    normalized_merges.append((merge[0], merge[1]))
-                elif isinstance(merge, str):
-                    parts = merge.split()
-                    if len(parts) == 2 and not merge.startswith("#"):
-                        normalized_merges.append((parts[0], parts[1]))
-        processed_merges = normalized_merges if normalized_merges else []
-
+        merges = merges or []
         tokenizer = Tokenizer(
             BPE(
-                vocab=
-                merges=
+                vocab=vocab,
+                merges=merges,
                 dropout=None,
                 continuing_subword_prefix="",
                 end_of_word_suffix="",

@@ -231,21 +204,11 @@ class MarkupLMTokenizer(TokenizersBackend):
         )
         tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
         tokenizer.decoder = decoders.ByteLevel()
-
-
-
-        tokenizer.post_processor = processors.RobertaProcessing(
-            sep=(sep_token_str, processed_vocab.get(sep_token_str, processed_vocab.get("</s>", 2))),
-            cls=(cls_token_str, processed_vocab.get(cls_token_str, processed_vocab.get("<s>", 0))),
-            add_prefix_space=add_prefix_space,
-            trim_offsets=trim_offsets,
-        )
-
+        self._vocab = vocab
+        self._merges = merges
+        self._tokenizer = tokenizer
         super().__init__(
-            tokenizer_object=tokenizer,
             tags_dict=tags_dict,
-            vocab=vocab,
-            merges=merges,
             errors=errors,
             bos_token=bos_token,
             eos_token=eos_token,

@@ -263,14 +226,18 @@ class MarkupLMTokenizer(TokenizersBackend):
             only_label_first_subword=only_label_first_subword,
             **kwargs,
         )
-
-
-
-
-
-
-
-
+        sep_token_str = str(sep_token)
+        cls_token_str = str(cls_token)
+        cls_token_id = self.cls_token_id
+        sep_token_id = self.sep_token_id
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=f"{cls_token_str} $A {sep_token_str}",
+            pair=f"{cls_token_str} $A {sep_token_str} $B {sep_token_str}",
+            special_tokens=[
+                (cls_token_str, cls_token_id),
+                (sep_token_str, sep_token_id),
+            ],
+        )

         self.tags_dict = tags_dict

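Note: MarkupLMTokenizer now sets its post-processor with `processors.TemplateProcessing` (a CLS/SEP template) instead of `RobertaProcessing`, as shown above. A minimal standalone sketch of that `tokenizers` API with a toy vocabulary (illustrative values only, not the real MarkupLM setup):

```python
from tokenizers import Tokenizer, processors
from tokenizers.models import BPE

tok = Tokenizer(BPE(vocab={"<s>": 0, "</s>": 1, "a": 2}, merges=[]))
tok.post_processor = processors.TemplateProcessing(
    single="<s> $A </s>",
    pair="<s> $A </s> $B </s>",
    special_tokens=[("<s>", 0), ("</s>", 1)],
)
print(tok.encode("a").tokens)  # ['<s>', 'a', '</s>']
```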
@@ -2184,6 +2184,7 @@ class Mask2FormerModel(Mask2FormerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Mask2FormerModelOutput:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

@@ -2305,6 +2306,7 @@ class Mask2FormerForUniversalSegmentation(Mask2FormerPreTrainedModel):
         output_auxiliary_logits: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Mask2FormerForUniversalSegmentationOutput:
         r"""
         mask_labels (`list[torch.Tensor]`, *optional*):

@@ -1496,6 +1496,7 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> MaskFormerModelOutput:
         r"""
         Examples:

@@ -1667,6 +1668,7 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> MaskFormerForInstanceSegmentationOutput:
         r"""
         mask_labels (`list[torch.Tensor]`, *optional*):

@@ -738,6 +738,7 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
         output_hidden_states=None,
         interpolate_pos_encoding=False,
         return_dict=None,
+        **kwargs,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

@@ -815,6 +816,7 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BackboneOutput:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         output_hidden_states = (

@@ -540,6 +540,7 @@ class MBartEncoder(MBartPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         Args:

@@ -691,6 +692,7 @@ class MBartDecoder(MBartPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         Args:

@@ -919,6 +921,7 @@ class MBartModel(MBartPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Seq2SeqModelOutput, tuple[torch.FloatTensor]]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):

@@ -1052,6 +1055,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):

@@ -1205,6 +1209,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):

@@ -1338,6 +1343,7 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):

@@ -1480,6 +1486,7 @@ class MBartForCausalLM(MBartPreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
 from tokenizers.models import Unigram

@@ -58,13 +58,14 @@ class MBartTokenizer(TokenizersBackend):

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = Unigram

     prefix_tokens: list[int] = []
     suffix_tokens: list[int] = []

     def __init__(
         self,
+        vocab: Optional[Union[str, dict, list]] = None,
         bos_token="<s>",
         eos_token="</s>",
         sep_token="</s>",

@@ -75,9 +76,6 @@ class MBartTokenizer(TokenizersBackend):
         src_lang=None,
         tgt_lang=None,
         additional_special_tokens=None,
-        vocab=None,
-        merges=None,  # Ignored for Unigram
-        vocab_file=None,
         **kwargs,
     ):
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

@@ -88,56 +86,20 @@ class MBartTokenizer(TokenizersBackend):
             [t for t in additional_special_tokens if t not in _additional_special_tokens]
         )

-
-
-        # Handle different vocab formats (dict, list of tokens, or list of tuples)
-        # SentencePieceExtractor returns list[tuple[str, float]] which is the expected format
-        if isinstance(vocab, dict):
-            vocab = [(token, 0.0) for token in vocab.keys()]
-        elif isinstance(vocab, list) and len(vocab) > 0:
-            if not isinstance(vocab[0], tuple):
-                vocab = [(token, 0.0) for token in vocab]
-            else:
-                # Ensure tuples are (str, float) format
-                vocab = [(str(item[0]), float(item[1])) for item in vocab]
-
-            # Reorder to fairseq: <s>, <pad>, </s>, <unk>, ... (rest of vocab from SPM[3:])
-            vocab_list = []
-            vocab_list.append((str(bos_token), 0.0))
-            vocab_list.append((str(pad_token), 0.0))
-            vocab_list.append((str(eos_token), 0.0))
-            vocab_list.append((str(unk_token), 0.0))
-
-            # Add the rest of the SentencePiece vocab (skipping first 3: <unk>, <s>, </s>)
-            vocab_list.extend(vocab[4:])
-
-            # Add language codes
-            for lang_code in FAIRSEQ_LANGUAGE_CODES:
-                vocab_list.append((str(lang_code), 0.0))
-
-            # Add mask token
-            vocab_list.append((str(mask_token), 0.0))
-
-            self._vocab_scores = vocab_list
-        else:
-            self._vocab_scores = [
+        if vocab is None:
+            vocab = [
                 (str(bos_token), 0.0),
                 (str(pad_token), 0.0),
                 (str(eos_token), 0.0),
                 (str(unk_token), 0.0),
-                ("▁", -2.0),
             ]
+            vocab += [("▁", -2.0)]
             for lang_code in FAIRSEQ_LANGUAGE_CODES:
-
-
-
-                self.
-
-                    self._vocab_scores,
-                    unk_id=3,
-                    byte_fallback=False,
-                )
-            )
+                vocab.append((lang_code, 0.0))
+            vocab.append((str(mask_token), 0.0))
+
+        self._vocab = vocab
+        self._tokenizer = Tokenizer(Unigram(self._vocab, unk_id=3, byte_fallback=False))

         self._tokenizer.normalizer = None

@@ -150,10 +112,7 @@ class MBartTokenizer(TokenizersBackend):

         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)

-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             bos_token=bos_token,
             eos_token=eos_token,
             sep_token=sep_token,