transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/modeling_utils.py
CHANGED
@@ -36,7 +36,7 @@ from typing import Optional, TypeVar, Union, get_type_hints
 from zipfile import is_zipfile
 
 import torch
-from huggingface_hub import create_repo, split_torch_state_dict_into_shards
+from huggingface_hub import create_repo, is_offline_mode, split_torch_state_dict_into_shards
 from packaging import version
 from safetensors import safe_open
 from safetensors.torch import save_file as safe_save_file
@@ -85,7 +85,7 @@ from .integrations.tensor_parallel import (
     verify_tp_plan,
 )
 from .loss.loss_utils import LOSS_MAPPING
-from .modeling_flash_attention_utils import lazy_import_flash_attention
+from .modeling_flash_attention_utils import lazy_import_flash_attention, lazy_import_paged_flash_attention
 from .pytorch_utils import id_tensor_storage
 from .quantizers import HfQuantizer
 from .quantizers.auto import get_hf_quantizer
@@ -93,7 +93,6 @@ from .quantizers.quantizers_utils import get_module_from_name
 from .safetensors_conversion import auto_conversion
 from .utils import (
     ADAPTER_SAFE_WEIGHTS_NAME,
-    ADAPTER_WEIGHTS_NAME,
     DUMMY_INPUTS,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
@@ -110,7 +109,6 @@ from .utils import (
     is_flash_attn_2_available,
     is_flash_attn_3_available,
     is_kernels_available,
-    is_offline_mode,
     is_torch_flex_attn_available,
     is_torch_greater_or_equal,
     is_torch_mlu_available,
@@ -279,7 +277,9 @@ def get_state_dict_dtype(state_dict):
         return t.dtype
 
     # if no floating dtype was found return whatever the first dtype is
-
+    if len(state_dict) == 0:
+        return torch.float32
+    return next(iter(state_dict.values())).dtype
 
 
 str_to_torch_dtype = {
@@ -552,8 +552,7 @@ def _get_resolved_checkpoint_files(
         raise OSError(
             f"{pretrained_model_name_or_path} does not appear to have a file named"
             f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
-            "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
-            "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
+            "and thus cannot be loaded with `safetensors`. Please do not set `use_safetensors=True`."
         )
     else:
         # This repo has no safetensors file of any kind, we switch to PyTorch.
@@ -772,7 +771,7 @@ def _get_dtype(
     for key in config.sub_configs:
         if (sub_config := getattr(config, key)) is not None:
             sub_config.dtype = default_dtype
-
+    dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
     return config, dtype, dtype_orig
 
 
@@ -799,7 +798,11 @@ class ModuleUtilsMixin:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
-
+        dtype = self._dtype or next(param.dtype for param in self.parameters() if param.is_floating_point())
+        if isinstance(dtype, str):
+            if hasattr(torch, dtype):
+                dtype = getattr(torch, dtype)
+        return dtype
 
    def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
        """
@@ -1078,6 +1081,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    _keep_in_fp32_modules_strict = None
 
    dtype_plan: Optional[dict[str, torch.dtype]] = None
+    _dtype: Optional[Union[str, torch.dtype]] = torch.get_default_dtype()
 
    # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
    # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -1222,6 +1226,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
            f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
        )
    self.config = config
+    default_dtype = torch.get_default_dtype()
+    self._dtype = default_dtype
 
    # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
    # setting it recursively)
@@ -1460,6 +1466,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
    `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
    """
+    if isinstance(dtype, str):
+        if hasattr(torch, dtype):
+            dtype = getattr(torch, dtype)
+        else:
+            raise ValueError(f"Received an invalid string dtype: {dtype}")
    if not dtype.is_floating_point:
        raise ValueError(
            f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
@@ -1468,6 +1479,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
    dtype_orig = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
+    cls._dtype = dtype
    return dtype_orig
 
    @property
@@ -1764,9 +1776,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    """
    applicable_attn_implementation = attn_implementation
 
+    is_paged = attn_implementation is not None and attn_implementation.startswith("paged|")
+
    # If FA not installed, do not fail but use kernels instead
    requested_original_flash_attn = attn_implementation is not None and (
-        attn_implementation
+        attn_implementation.removeprefix("paged|") == "flash_attention_2"
+        or attn_implementation.removeprefix("paged|") == "flash_attention_3"
    )
    if (
        requested_original_flash_attn
@@ -1784,10 +1799,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        else:
            applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
 
+    if is_paged:
+        applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
+
    if is_kernel(applicable_attn_implementation):
        try:
            # preload flash attention here to allow compile with fullgraph
-
+            if is_paged:
+                lazy_import_paged_flash_attention(applicable_attn_implementation)
+            else:
+                lazy_import_flash_attention(applicable_attn_implementation)
 
            # log that we used kernel fallback if successful
            if requested_original_flash_attn:
@@ -2104,7 +2125,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    possible_module_names = ["language_model", "text_model", "decoder"]
    for name in possible_module_names:
        if hasattr(self, name):
-            print(name)
            setattr(self, name, decoder)
            return
 
@@ -3002,10 +3022,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    state_dict: Optional[dict] = None,
-    save_function: Callable = torch.save,
    push_to_hub: bool = False,
-    max_shard_size: Union[int, str] = "5GB",
-    safe_serialization: bool = True,
+    max_shard_size: Union[int, str] = "50GB",
    variant: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    save_peft_format: bool = True,
@@ -3027,18 +3045,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
        save parts of the model or if special precautions need to be taken when recovering the state dictionary
        of a model (like when using model parallelism).
-    save_function (`Callable`):
-        The function to use to save the state dictionary. Useful on distributed training like TPUs when one
-        need to replace `torch.save` by another method.
    push_to_hub (`bool`, *optional*, defaults to `False`):
        Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
        repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
        namespace).
-    max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
+    max_shard_size (`int` or `str`, *optional*, defaults to `"50GB"`):
        The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
        lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
-        We default it to 5GB in order for models to be able to run easily on free-tier google colab instances
-        without CPU OOM issues.
 
        <Tip warning={true}>
 
@@ -3047,10 +3060,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
        </Tip>
 
-    safe_serialization (`bool`, *optional*, defaults to `True`):
-        Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
    variant (`str`, *optional*):
-        If specified, weights are saved in the format
+        If specified, weights are saved in the format model.<variant>.safetensors.
    token (`str` or `bool`, *optional*):
        The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
        the token generated when running `hf auth login` (stored in `~/.huggingface`).
@@ -3072,9 +3083,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
    hf_quantizer = getattr(self, "hf_quantizer", None)
    quantization_serializable = (
-        hf_quantizer is not None
-        and isinstance(hf_quantizer, HfQuantizer)
-        and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
+        hf_quantizer is not None and isinstance(hf_quantizer, HfQuantizer) and hf_quantizer.is_serializable()
    )
 
    if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
@@ -3110,7 +3119,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
    metadata = {}
    if hf_quantizer is not None:
-        state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self
+        state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self)
    metadata["format"] = "pt"
 
    # Only save the model itself if we are using distributed training
@@ -3202,75 +3211,72 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    if self._tp_size is not None:
        state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
 
-    [… 15 lines of the earlier implementation are not rendered in this diff view …]
-    # (state_dict tensors are detached and therefore no longer shared)
-    tensor = self.get_parameter(name)
-    ptrs[id(tensor)].append(name)
+    # Safetensors does not allow tensor aliasing - we're going to remove aliases before saving
+    ptrs = collections.defaultdict(list)
+    for name, tensor in state_dict.items():
+        if not isinstance(tensor, torch.Tensor):
+            # Sometimes in the state_dict we have non-tensor objects.
+            # e.g. in bitsandbytes we have some `str` objects in the state_dict
+            # In the non-tensor case, fall back to the pointer of the object itself
+            ptrs[id(tensor)].append(name)
+
+        elif tensor.device.type == "meta":
+            # In offloaded cases, there may be meta tensors in the state_dict.
+            # For these cases, key by the pointer of the original tensor object
+            # (state_dict tensors are detached and therefore no longer shared)
+            tensor = self.get_parameter(name)
+            ptrs[id(tensor)].append(name)
 
-    [… 50 lines of the earlier implementation are not rendered in this diff view …]
+        else:
+            ptrs[id_tensor_storage(tensor)].append(name)
+
+    shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+
+    # Recursively descend to find tied weight keys
+    _tied_weights_keys = set(_get_tied_weight_keys(self))
+    error_names = []
+    to_delete_names = set()
+    for names in shared_ptrs.values():
+        # Removing the keys which are declared as known duplicates on
+        # load. This allows to make sure the name which is kept is consistent.
+        if _tied_weights_keys is not None:
+            found = 0
+            for name in sorted(names):
+                matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
+                if matches_pattern and name in state_dict:
+                    found += 1
+                    if found < len(names):
+                        to_delete_names.add(name)
+    # We are entering a place where the weights and the transformers configuration do NOT match.
+    shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+    # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+    # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+    for name in disjoint_names:
+        state_dict[name] = state_dict[name].clone()
+
+    # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+    # If the link between tensors was done at runtime then `from_pretrained` will not get
+    # the key back leading to random tensor. A proper warning will be shown
+    # during reload (if applicable), but since the file is not necessarily compatible with
+    # the config, better show a proper warning.
+    shared_names, identical_names = _find_identical(shared_names, state_dict)
+    # delete tensors that have identical storage
+    for inames in identical_names:
+        known = inames.intersection(to_delete_names)
+        for name in known:
+            del state_dict[name]
+        unknown = inames.difference(to_delete_names)
+        if len(unknown) > 1:
+            error_names.append(unknown)
+
+    if shared_names:
+        error_names.extend(shared_names)
+
+    if len(error_names) > 0:
+        raise RuntimeError(
+            f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
+            "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
+        )
 
    # Revert all renaming and/or weight operations
    if save_original_format:
@@ -3278,10 +3284,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
    # Shard the model if it is too big.
    if not _hf_peft_config_loaded:
-        weights_name = SAFE_WEIGHTS_NAME
+        weights_name = SAFE_WEIGHTS_NAME
        weights_name = _add_variant(weights_name, variant)
    else:
-        weights_name = ADAPTER_SAFE_WEIGHTS_NAME
+        weights_name = ADAPTER_SAFE_WEIGHTS_NAME
 
    filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
    state_dict_split = split_torch_state_dict_into_shards(
@@ -3350,13 +3356,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        del shard_state_dict
        gc.collect()
 
-
-
-
-        # too much before scheduling the next write when its in a different file
-        safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
-    else:
-        save_function(shard, os.path.join(save_directory, shard_file))
+    # TODO: we should def parallelize this we are otherwise just waiting
+    # too much before scheduling the next write when its in a different file
+    safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
 
    del state_dict
 
@@ -3364,7 +3366,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        path_to_weights = os.path.join(save_directory, weights_name)
        logger.info(f"Model weights saved in {path_to_weights}")
    else:
-        save_index_file = SAFE_WEIGHTS_INDEX_NAME
+        save_index_file = SAFE_WEIGHTS_INDEX_NAME
        save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
        # Save the index as well
        with open(save_index_file, "w", encoding="utf-8") as f:
@@ -3835,6 +3837,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    # For BC on torch_dtype argument
    if torch_dtype is not None:
        dtype = dtype if dtype is not None else torch_dtype
+    if dtype is None:
+        dtype = "auto"
 
    if is_offline_mode() and not local_files_only:
        local_files_only = True
@@ -4039,7 +4043,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    hf_quantizer.postprocess_model(model, config=config)  # usually a no-op but sometimes needed
 
    if _adapter_model_path is not None:
-        adapter_kwargs["key_mapping"] =
+        adapter_kwargs["key_mapping"] = key_mapping
        model.load_adapter(
            _adapter_model_path,
            adapter_name=adapter_name,
@@ -4090,10 +4094,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    # Prepare parameters offloading if needed
    if device_map is not None and "disk" in device_map.values():
        disk_offload_index = accelerate_disk_offload(
+            model,
            disk_offload_folder,
            checkpoint_files,
            device_map,
-            expected_keys,
            sharded_metadata,
            dtype,
            weight_mapping,
@@ -4115,7 +4119,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        state_dict = merged_state_dict
        error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
        # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
-        missing_keys, unexpected_keys, mismatched_keys,
+        missing_keys, unexpected_keys, mismatched_keys, conversion_errors = set(), set(), set(), set()
    else:
        all_pointer = set()
        # Checkpoints are safetensors
@@ -4137,7 +4141,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        else:
            raise ValueError("Neither a state dict nor checkpoint files were found.")
 
-        missing_keys, unexpected_keys, mismatched_keys, disk_offload_index,
+        missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, conversion_errors = (
            convert_and_load_state_dict_in_model(
                model,
                merged_state_dict,
@@ -4180,7 +4184,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        tp_device = list(device_map.values())[0]
        # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
        # not part of the state_dict (persistent=False)
-        for buffer in model.buffers():  # TODO to
+        for buffer in model.buffers():  # TODO to avoid this buffer could be added to the ckpt
            if buffer.device != tp_device:
                buffer.data = buffer.to(tp_device)
 
@@ -4211,7 +4215,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        missing_keys=missing_keys,
        mismatched_keys=mismatched_keys,
        mismatched_shapes=mismatched_keys,
-
+        conversion_errors=conversion_errors,
        ignore_mismatched_sizes=ignore_mismatched_sizes,
    )
transformers/models/__init__.py
CHANGED
@@ -126,6 +126,7 @@ if TYPE_CHECKING:
     from .falcon import *
     from .falcon_h1 import *
     from .falcon_mamba import *
+    from .fast_vlm import *
     from .fastspeech2_conformer import *
     from .flaubert import *
     from .flava import *
@@ -185,6 +186,7 @@ if TYPE_CHECKING:
     from .jetmoe import *
     from .kosmos2 import *
     from .kyutai_speech_to_text import *
+    from .lasr import *
     from .layoutlm import *
     from .layoutlmv2 import *
     from .layoutlmv3 import *
@@ -263,6 +265,7 @@ if TYPE_CHECKING:
     from .ovis2 import *
     from .owlv2 import *
     from .owlvit import *
+    from .paddleocr_vl import *
     from .paligemma import *
     from .parakeet import *
     from .patchtsmixer import *
transformers/models/afmoe/modeling_afmoe.py
CHANGED
@@ -28,7 +28,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_func_from_hub
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
 from ...integrations.hub_kernels import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
@@ -37,7 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_afmoe import AfmoeConfig
 
 
@@ -97,7 +97,7 @@ class AfmoeRotaryEmbedding(nn.Module):
    position_ids_expanded = position_ids[:, None, :].float()
 
    device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-    with
+    with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos() * self.attention_scaling
@@ -338,6 +338,7 @@ def eager_attention_forward(
    return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class AfmoeAttention(nn.Module):
    """
    Multi-headed attention module with optional sliding window and gating.
@@ -369,7 +370,6 @@ class AfmoeAttention(nn.Module):
    self.o_proj = nn.Linear(
        config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
    )
-    self.rotary_fn = apply_rotary_pos_emb
    # Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
    # We only add AFMoE-specific attributes
    self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
transformers/models/albert/tokenization_albert.py
CHANGED
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Tokenization classes for ALBERT model."""
 
-from typing import Optional
+from typing import Optional, Union
 
 from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import Unigram
@@ -73,8 +73,8 @@ class AlbertTokenizer(TokenizersBackend):
        other word.
    trim_offsets (`bool`, *optional*, defaults to `True`):
        Whether the post processing step should trim offsets to avoid including whitespaces.
-    vocab (`
-        Custom vocabulary
+    vocab (`str` or `list[tuple[str, float]]`, *optional*):
+        Custom vocabulary with `(token, score)` tuples. If not provided, vocabulary is loaded from `vocab_file`.
    vocab_file (`str`, *optional*):
        [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
        contains the vocabulary necessary to instantiate a tokenizer.
@@ -82,10 +82,11 @@ class AlbertTokenizer(TokenizersBackend):
 
    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
-
+    model = Unigram
 
    def __init__(
        self,
+        vocab: Optional[Union[str, list[tuple[str, float]]]] = None,
        do_lower_case: bool = True,
        keep_accents: bool = False,
        bos_token: str = "[CLS]",
@@ -97,19 +98,15 @@ class AlbertTokenizer(TokenizersBackend):
        mask_token: str = "[MASK]",
        add_prefix_space: bool = True,
        trim_offsets: bool = True,
-        vocab: Optional[dict] = None,
-        vocab_file: Optional[str] = None,
        **kwargs,
    ):
-        self.vocab_file = vocab_file
        self.add_prefix_space = add_prefix_space
        self.trim_offsets = trim_offsets
-
        self.do_lower_case = do_lower_case
        self.keep_accents = keep_accents
 
        if vocab is not None:
-            self._vocab_scores =
+            self._vocab_scores = vocab
        else:
            self._vocab_scores = [
                (str(pad_token), 0.0),
@@ -163,10 +160,7 @@ class AlbertTokenizer(TokenizersBackend):
            ],
        )
 
-        tokenizer_object = self._tokenizer
-
        super().__init__(
-            tokenizer_object=tokenizer_object,
            do_lower_case=self.do_lower_case,
            keep_accents=self.keep_accents,
            bos_token=bos_token,
transformers/models/align/modeling_align.py
CHANGED
@@ -1004,6 +1004,7 @@ class AlignVisionModel(AlignPreTrainedModel):
    pixel_values: Optional[torch.FloatTensor] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
+    **kwargs,
 ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
    r"""
    Examples:
@@ -1169,6 +1170,7 @@ class AlignModel(AlignPreTrainedModel):
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
+    **kwargs,
 ) -> Union[tuple, AlignOutput]:
    r"""
    return_loss (`bool`, *optional*):
transformers/models/altclip/modeling_altclip.py
CHANGED
@@ -891,6 +891,7 @@ class AltCLIPVisionModel(AltCLIPPreTrainedModel):
    output_hidden_states: Optional[bool] = None,
    interpolate_pos_encoding: bool = False,
    return_dict: Optional[bool] = None,
+    **kwargs,
 ) -> Union[tuple, BaseModelOutputWithPooling]:
    r"""
    Examples:
@@ -970,6 +971,7 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
+    **kwargs,
 ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
@@ -1061,6 +1063,7 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
+    **kwargs,
 ) -> Union[tuple, BaseModelOutputWithPoolingAndProjection]:
    r"""
    Examples:
@@ -1236,6 +1239,7 @@ class AltCLIPModel(AltCLIPPreTrainedModel):
    output_hidden_states: Optional[bool] = None,
    interpolate_pos_encoding: bool = False,
    return_dict: Optional[bool] = None,
+    **kwargs,
 ) -> Union[tuple, AltCLIPOutput]:
    r"""
    return_loss (`bool`, *optional*):