transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/tokenization_utils_tokenizers.py  CHANGED

@@ -20,18 +20,18 @@ import copy
 import json
 import os
 from collections import defaultdict
+from collections.abc import Iterable
 from shutil import copyfile
 from typing import Any, Optional, Union

 import tokenizers.pre_tokenizers as pre_tokenizers_fast
+from huggingface_hub import is_offline_mode
 from tokenizers import AddedToken, processors
 from tokenizers import Encoding as EncodingFast
 from tokenizers import Tokenizer as TokenizerFast
-from tokenizers import normalizers as tokenizers_normalizers
 from tokenizers.decoders import Decoder as DecoderFast
 from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer

-from .convert_slow_tokenizer import convert_slow_tokenizer
 from .integrations.ggml import convert_gguf_tokenizer
 from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
 from .tokenization_utils_base import (

@@ -41,8 +41,9 @@ from .tokenization_utils_base import (
     PreTrainedTokenizerBase,
     TextInput,
     TruncationStrategy,
+    generate_merges,
 )
-from .utils import PaddingStrategy, add_end_docstrings,
+from .utils import PaddingStrategy, add_end_docstrings, logging


 logger = logging.get_logger(__name__)
@@ -90,26 +91,157 @@ class TokenizersBackend(PreTrainedTokenizerBase):
     """

     vocab_files_names = VOCAB_FILES_NAMES
+    model = None
+    _tokenizer = None
+
+    @classmethod
+    def convert_to_native_format(cls, trust_remote_code=False, **kwargs):
+        """s
+        Build a `tokenizers.Tokenizer` backend from the available serialization files (tokenizer.json, sentencepiece
+        models, tekken.json, vocab/merges).
+        """
+        # Preserve kwargs for possible downstream use
+        local_kwargs = dict(kwargs)
+        fast_tokenizer_file = local_kwargs.pop("tokenizer_file", None)
+
+        if (
+            fast_tokenizer_file is not None
+            and os.path.isfile(fast_tokenizer_file)
+            and (cls is TokenizersBackend or "__init__" not in cls.__dict__ or trust_remote_code)
+        ):
+            local_kwargs["tokenizer_object"] = TokenizerFast.from_file(fast_tokenizer_file)
+            return local_kwargs
+        elif fast_tokenizer_file is not None and os.path.isfile(fast_tokenizer_file):
+            # we extract vocab / merges from the tokenizer file to pass them to __init__
+            processor = TokenizerFast.from_file(fast_tokenizer_file).post_processor
+            with open(fast_tokenizer_file, encoding="utf-8") as tokenizer_handle:
+                tokenizer_json = json.load(tokenizer_handle)
+            vocab = tokenizer_json.get("model", {}).get("vocab", None)
+            if cls.model is None:
+                if isinstance(vocab, list):
+                    vocab = list(map(tuple, vocab))  # TODO just for now
+            elif cls.model.__name__ == "Unigram":
+                vocab = list(map(tuple, vocab))
+            elif cls.model.__name__ == "WordLevel":
+                vocab = {token: i for i, token in enumerate(vocab)}
+            elif cls.model.__name__ == "BPE" or cls.model.__name__ == "WordPiece":
+                if isinstance(vocab, list):
+                    vocab = {token[0] if isinstance(token, list) else token: i for i, token in enumerate(vocab)}
+            local_kwargs["vocab"] = vocab
+
+            model_type = getattr(cls, "model", None)
+            if "merges" in tokenizer_json.get("model", {}) and (model_type and model_type.__name__ == "BPE"):
+                merges = tokenizer_json["model"]["merges"]
+                merges = [tuple(merge.split(" ")) if isinstance(merge, str) else tuple(merge) for merge in merges]
+                local_kwargs["merges"] = merges
+
+            if processor is not None:
+                local_kwargs["post_processor"] = processor
+            return local_kwargs
+
+        vocab_file = local_kwargs.get("vocab_file")
+        merges_file = local_kwargs.get("merges_file")
+        vocab = local_kwargs.get("vocab")
+        merges = local_kwargs.get("merges")
+
+        # Tekken converter (Mistral)
+        if isinstance(vocab_file, str) and vocab_file.endswith("tekken.json") and os.path.isfile(vocab_file):
+            from .convert_slow_tokenizer import MistralConverter
+
+            local_kwargs["vocab"], local_kwargs["merges"] = MistralConverter(
+                vocab_file=vocab_file
+            ).extract_vocab_merges_from_model(vocab_file)
+            return local_kwargs
+
+        # SentencePiece model (with TikToken fallback)
+        if isinstance(vocab_file, str) and os.path.isfile(vocab_file) and vocab_file.endswith(".model"):
+            try:
+                from .convert_slow_tokenizer import SentencePieceExtractor
+
+                local_kwargs = SentencePieceExtractor(vocab_file).extract(cls.model, **local_kwargs)
+                try:
+                    from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
+
+                    converter_class = SLOW_TO_FAST_CONVERTERS.get(cls.__name__)
+                    if converter_class is not None and hasattr(converter_class, "convert_from_spm"):
+                        local_kwargs = converter_class.convert_from_spm(**local_kwargs)
+                except Exception as e:
+                    logger.warning(
+                        f"Could not reorder vocab using converter for {cls.__name__} due to {e}. Falling back to raw SentencePiece extraction."
+                    )
+                # what used to be in `convert_slow`
+                if hasattr(cls, "convert_from_spm_model"):
+                    local_kwargs = cls.convert_from_spm_model(**local_kwargs)
+            except Exception as e:  # TODO only catch deserialization error here!
+                logger.warning(
+                    f"Could not extract SentencePiece model from {vocab_file} using sentencepiece library due to {e}. "
+                    "Falling back to TikToken extractor."
+                )
+                from .convert_slow_tokenizer import TikTokenConverter
+
+                local_kwargs["vocab"], local_kwargs["merges"] = TikTokenConverter(
+                    vocab_file=vocab_file, extra_special_tokens=local_kwargs.get("extra_special_tokens")
+                ).extract_vocab_merges_from_model(vocab_file)
+            return local_kwargs
+
+        # Fallback to standard vocab/merges files if they existed!
+        if vocab is None and isinstance(vocab_file, str) and os.path.isfile(vocab_file):
+            local_kwargs["vocab"] = vocab_file
+            vocab = local_kwargs["vocab"]
+        if merges is None and isinstance(merges_file, str) and os.path.isfile(merges_file):
+            local_kwargs["merges"] = merges_file
+            merges = local_kwargs["merges"]
+
+        # Generate merges automatically when not provided for BPE tokenizers
+        if merges is None and cls.model is not None and cls.model.__name__ == "BPE" and isinstance(vocab, dict):
+            # Gather special tokens from kwargs to skip in merge generation
+            def _iter_special_tokens(values: Iterable[Any]) -> list[str]:
+                collected: list[str] = []
+                for val in values:
+                    if val is None:
+                        continue
+                    if isinstance(val, (list, tuple)):
+                        collected.extend(_iter_special_tokens(val))
+                    else:
+                        collected.append(str(val))
+                return collected

+            special_tokens_keys = [
+                "pad_token",
+                "unk_token",
+                "bos_token",
+                "eos_token",
+                "sep_token",
+                "cls_token",
+                "mask_token",
+                "additional_special_tokens",
+                "extra_special_tokens",
+            ]
+            skip_tokens: set[str] = set()
+            for key in special_tokens_keys:
+                if key in local_kwargs:
+                    skip_tokens.update(_iter_special_tokens([local_kwargs[key]]))
+
+            merges = generate_merges(vocab, skip_tokens=skip_tokens)
+            local_kwargs["merges"] = merges
+        return local_kwargs

     def __init__(self, *args, **kwargs):
         tokenizer_object = kwargs.pop("tokenizer_object", None)
-        slow_tokenizer = kwargs.pop("__slow_tokenizer", None)
         gguf_file = kwargs.pop("gguf_file", None)
         fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
-        from_slow = kwargs.pop("from_slow", False)
         # Note: added_tokens_decoder is NOT popped - it's passed to super().__init__() for processing
         added_tokens_decoder = kwargs.get("added_tokens_decoder", {})
         # Store add_prefix_space before super().__init__() to ensure it's not overridden
         add_prefix_space = kwargs.get("add_prefix_space", False)
+        vocab_file = kwargs.get("vocab_file")

+        fast_tokenizer = None
         if tokenizer_object is not None:
             fast_tokenizer = copy.deepcopy(tokenizer_object)
-        elif fast_tokenizer_file is not None and
+        elif fast_tokenizer_file is not None and os.path.isfile(fast_tokenizer_file):
             # We have a serialization from tokenizers which let us directly build the backend
             fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
-        elif slow_tokenizer:
-            # We need to convert a slow tokenizer to build the backend
-            fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
         elif gguf_file is not None:
             # We need to convert a slow tokenizer to build the backend
             gguf_param = load_gguf_checkpoint(kwargs.get("vocab_file"))
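The last branch above falls back to `generate_merges(vocab, skip_tokens=...)` when a BPE tokenizer has a vocab but no merges. As a rough illustration of that idea only (a hypothetical simplification, not the `generate_merges` helper imported from `tokenization_utils_base`), merges can be recovered from a rank-ordered vocab by splitting each token into the lowest-ranked pair of known sub-tokens, while skipping special tokens:

```python
def sketch_generate_merges(vocab: dict[str, int], skip_tokens: set[str]) -> list[tuple[str, str]]:
    """Hypothetical, simplified merge recovery; not the library implementation."""
    merges: list[tuple[str, str]] = []
    for token, _rank in sorted(vocab.items(), key=lambda kv: kv[1]):
        if token in skip_tokens or len(token) < 2:
            continue
        best = None  # (left_rank, right_rank, left, right)
        for i in range(1, len(token)):
            left, right = token[:i], token[i:]
            if left in vocab and right in vocab:
                candidate = (vocab[left], vocab[right], left, right)
                if best is None or candidate[:2] < best[:2]:
                    best = candidate
        if best is not None:
            merges.append((best[2], best[3]))
    return merges


vocab = {"l": 0, "o": 1, "w": 2, "lo": 3, "low": 4, "[UNK]": 5}
print(sketch_generate_merges(vocab, skip_tokens={"[UNK]"}))  # [('l', 'o'), ('lo', 'w')]
```

Real merge recovery also has to respect byte-level alphabets and merge ordering, which is presumably why the library routes this through a dedicated helper rather than an ad-hoc loop like the one above.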
@@ -120,18 +252,7 @@ class TokenizersBackend(PreTrainedTokenizerBase):
             kwargs.update(tokenizer_config)
             if len(additional_kwargs) > 0:
                 kwargs.update(additional_kwargs)
-        elif self.
-            # We need to create and convert a slow tokenizer to build the backend
-            slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
-            fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
-        elif not slow_tokenizer:
-            # We tried loading a slow_tokenizer with spm and failed, try to load with tiktoken
-            self.vocab_file = kwargs.get("vocab_file")
-            # V5: Set _extra_special_tokens directly for converter
-            self._extra_special_tokens = kwargs.get("extra_special_tokens", [])
-            fast_tokenizer = convert_slow_tokenizer(self, from_tiktoken=True)
-            slow_tokenizer = None
-        else:
+        elif self._tokenizer is None:
             raise ValueError(
                 "Couldn't instantiate the backend tokenizer from one of: \n"
                 "(1) a `tokenizers` library serialization file, \n"

@@ -139,11 +260,11 @@ class TokenizersBackend(PreTrainedTokenizerBase):
                 "(3) an equivalent slow tokenizer class to instantiate and convert. \n"
                 "You need to have sentencepiece or tiktoken installed to convert a slow tokenizer to a fast one."
             )
+        if fast_tokenizer is not None:
+            self._tokenizer = fast_tokenizer

-        self._tokenizer
-
-        if slow_tokenizer is not None:
-            kwargs.update(slow_tokenizer.init_kwargs)
+        if self._tokenizer is None:
+            raise ValueError("The backend tokenizer is not correctly initialized.")

         _truncation = self._tokenizer.truncation

@@ -168,9 +289,17 @@ class TokenizersBackend(PreTrainedTokenizerBase):
         # Set backend to "tokenizers" if not already set
         if "backend" not in kwargs:
             kwargs["backend"] = "tokenizers"
-
+        explicit_bos_eos_in_kwargs = "add_bos_token" in kwargs or "add_eos_token" in kwargs
+        self._add_bos_token = kwargs.get("add_bos_token", False)
+        self._add_eos_token = kwargs.get("add_eos_token", False)
+        if post_processor := kwargs.pop("post_processor", None):  # most reliable way to get the post-processor
+            self._tokenizer.post_processor = post_processor
+        self._should_update_post_processor = explicit_bos_eos_in_kwargs or self._tokenizer.post_processor is None
         # We call this after having initialized the backend tokenizer because we update it.
         super().__init__(**kwargs)
+
+        if vocab_file is not None:
+            self.vocab_file = vocab_file
         # Ensure add_prefix_space is set correctly after parent init
         self.add_prefix_space = add_prefix_space
         self._tokenizer.encode_special_tokens = self.split_special_tokens

@@ -228,6 +357,12 @@ class TokenizersBackend(PreTrainedTokenizerBase):
             **kwargs,
         )

+        self._should_update_post_processor = (
+            self._should_update_post_processor or self._tokenizer.post_processor is None
+        )
+        if self._should_update_post_processor:
+            self.update_post_processor()
+
     @property
     def is_fast(self) -> bool:
         return True

@@ -273,7 +408,7 @@ class TokenizersBackend(PreTrainedTokenizerBase):
         # If eos_token is None and add_eos_token is True, silently disable add_eos_token
         # This allows tokenizers to set add_eos_token even if eos_token is not configured
         if eos is None and self.add_eos_token:
-            self.
+            self.add_eos_token = False
             return

         single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
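For context on the template string built above (`single = f"{bos}:0 $A:0 {eos}:0"`), here is a hedged, standalone sketch of the `tokenizers` post-processor such a template typically maps to. The token names and ids below are placeholders, not values taken from any particular tokenizer:

```python
from tokenizers import processors

# Placeholder special tokens; in a real tokenizer these come from the vocab.
post_processor = processors.TemplateProcessing(
    single="[BOS]:0 $A:0 [EOS]:0",
    pair="[BOS]:0 $A:0 [EOS]:0 [BOS]:1 $B:1 [EOS]:1",
    special_tokens=[("[BOS]", 1), ("[EOS]", 2)],  # (token, token_id) pairs
)
# Assigning this to a `tokenizers.Tokenizer` (tokenizer.post_processor = post_processor)
# makes encode() add the BOS/EOS tokens and type ids automatically.
```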
@@ -320,98 +455,24 @@ class TokenizersBackend(PreTrainedTokenizerBase):
             if token_value is None:
                 continue
             if isinstance(token_value, AddedToken):
-
-                tokens_to_add.append(token_value)
+                tokens_to_add.append(token_value)
             elif isinstance(token_value, str):
-
-                tokens_to_add.append(AddedToken(token_value, special=True, normalized=False))
+                tokens_to_add.append(AddedToken(token_value, special=True, normalized=False))

         # V5: Check extra special tokens
         for token in self._extra_special_tokens:
             if isinstance(token, AddedToken):
-
-                tokens_to_add.append(token)
+                tokens_to_add.append(token)
             elif isinstance(token, str):
-
-                tokens_to_add.append(AddedToken(token, special=True, normalized=False))
+                tokens_to_add.append(AddedToken(token, special=True, normalized=False))

         if tokens_to_add:
             # Ensure special tokens are added as such to the backend
             self.add_tokens(tokens_to_add, special_tokens=True)

-        if
+        if getattr(self, "_should_update_post_processor", True) or self._tokenizer.post_processor is None:
             self.update_post_processor()

-        # Update add_prefix_space in the pre_tokenizer if needed
-        if hasattr(self, "add_prefix_space"):
-            try:
-                tokenizer_json = json.loads(self.backend_tokenizer.to_str())
-                pre_tok = tokenizer_json.get("pre_tokenizer", {})
-
-                # Recursively update add_prefix_space in pretokenizers
-                def update_add_prefix_space(pretok_dict, value):
-                    updated = False
-                    if pretok_dict.get("type") == "Sequence":
-                        for nested in pretok_dict.get("pretokenizers", []):
-                            updated |= update_add_prefix_space(nested, value)
-                    elif "add_prefix_space" in pretok_dict and pretok_dict["add_prefix_space"] != value:
-                        pretok_dict["add_prefix_space"] = value
-                        updated = True
-                    return updated
-
-                if update_add_prefix_space(pre_tok, self.add_prefix_space):
-                    self._tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))
-            except Exception:
-                pass
-
-        # Ensure normalizer flags (lowercase/accents/chinese chars) reflect tokenizer attributes
-        try:
-            normalizer = self.backend_tokenizer.normalizer
-            if normalizer is not None:
-                norm_state = json.loads(normalizer.__getstate__())
-                norm_type = norm_state.get("type")
-
-                desired_lowercase = getattr(self, "do_lower_case", None)
-                desired_strip_accents = getattr(self, "strip_accents", None)
-                # Some tokenizers expose keep_accents instead of strip_accents
-                if desired_strip_accents is None and hasattr(self, "keep_accents") and "strip_accents" in norm_state:
-                    keep_accents_value = getattr(self, "keep_accents")
-                    if keep_accents_value is not None:
-                        desired_strip_accents = not keep_accents_value
-                desired_handle_chinese = getattr(self, "tokenize_chinese_chars", None)
-
-                updated = False
-                if (
-                    desired_lowercase is not None
-                    and "lowercase" in norm_state
-                    and norm_state["lowercase"] != desired_lowercase
-                ):
-                    norm_state["lowercase"] = desired_lowercase
-                    updated = True
-                if (
-                    desired_strip_accents is not None
-                    and "strip_accents" in norm_state
-                    and norm_state["strip_accents"] != desired_strip_accents
-                ):
-                    norm_state["strip_accents"] = desired_strip_accents
-                    updated = True
-                if (
-                    desired_handle_chinese is not None
-                    and "handle_chinese_chars" in norm_state
-                    and norm_state["handle_chinese_chars"] != desired_handle_chinese
-                ):
-                    norm_state["handle_chinese_chars"] = desired_handle_chinese
-                    updated = True
-
-                if updated and norm_type is not None:
-                    norm_class = getattr(tokenizers_normalizers, norm_type, None)
-                    if norm_class is not None:
-                        norm_state.pop("type", None)
-                        self.backend_tokenizer.normalizer = norm_class(**norm_state)
-        except Exception:
-            # Best-effort: do not block initialization on normalizer reconciliation
-            pass
-
     @property
     def vocab_size(self) -> int:
         """
@@ -1132,7 +1193,7 @@ class TokenizersBackend(PreTrainedTokenizerBase):
            ]
        ):
            return tokenizer
-        elif transformers_version and version.parse(transformers_version)
+        elif transformers_version and version.parse(transformers_version) > version.parse("4.57.3"):
            return tokenizer

        mistral_config_detected = True
transformers/trainer.py  CHANGED

@@ -642,6 +642,16 @@ class Trainer:
                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
             )
         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+
+        # Add JIT checkpoint callback if enabled
+        if self.args.enable_jit_checkpoint:
+            from .trainer_jit_checkpoint import JITCheckpointCallback
+
+            jit_callback = JITCheckpointCallback()
+            default_callbacks = default_callbacks + [jit_callback]
+            # Set trainer reference for JIT callback after initialization
+            jit_callback.set_trainer(self)
+
         callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
         self.callback_handler = CallbackHandler(
             callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
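Based on the `self.args.enable_jit_checkpoint` check above (and the `training_args.py` changes in this release), enabling the new callback from user code presumably looks like the sketch below; treat the flag name as read from this diff rather than from documentation:

```python
from transformers import TrainingArguments

# Assumption: `enable_jit_checkpoint` is the TrainingArguments field backing
# `self.args.enable_jit_checkpoint`; when True, Trainer appends JITCheckpointCallback.
args = TrainingArguments(output_dir="out", enable_jit_checkpoint=True)
```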
@@ -2338,6 +2348,8 @@ class Trainer:

         if self.is_fsdp_enabled:
             self.model = self.model_wrapped = model
+            # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA
+            dist.fsdp.register_fsdp_forward_method(self.model, "generate")

         # for the rest of this function `model` is the outside model, whether it was wrapped or not
         if model is not self.model:
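The call added above relies on a PyTorch FSDP2 helper. A minimal sketch of the same registration outside the Trainer, assuming a torch version (2.4+) where `register_fsdp_forward_method` is available:

```python
from torch.distributed.fsdp import register_fsdp_forward_method


def enable_fsdp2_generate(model) -> None:
    # Make FSDP2 all-gather sharded parameters around `generate` the same way it
    # does around `forward`; otherwise generation can fail with
    # "got mixed torch.Tensor and DTensor", as the diff comment notes.
    register_fsdp_forward_method(model, "generate")
```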
@@ -2428,8 +2440,6 @@ class Trainer:

         for epoch in range(epochs_trained, num_train_epochs):
             epoch_dataloader = train_dataloader
-            if hasattr(epoch_dataloader, "set_epoch"):
-                epoch_dataloader.set_epoch(epoch)

             steps_in_epoch = (
                 len(epoch_dataloader)

@@ -2450,6 +2460,9 @@ class Trainer:
             elif steps_trained_in_current_epoch == 0:
                 self._load_rng_state(resume_from_checkpoint)

+            if hasattr(epoch_dataloader, "set_epoch"):
+                epoch_dataloader.set_epoch(epoch)
+
             epoch_iterator = iter(epoch_dataloader)
             # We chunkify the epoch iterator into gradient accumulation steps `n` batches
             remainder = steps_in_epoch % args.gradient_accumulation_steps
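The `set_epoch` call moved above follows the usual PyTorch convention that per-epoch reshuffling only happens once the sampler (or an accelerate-prepared dataloader) is told the current epoch index. A minimal sketch of that convention, independent of the Trainer:

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
# Explicit num_replicas/rank so this runs without init_process_group.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reseeds the shuffle for this epoch
    for batch in loader:
        pass
```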
@@ -2788,7 +2801,7 @@ class Trainer:
                 )
             else:
                 # We load the model state dict on the CPU to avoid an OOM error.
-                if
+                if os.path.isfile(safe_weights_file):
                     state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
                 else:
                     check_torch_load_is_safe()

@@ -2828,9 +2841,7 @@ class Trainer:
                 logger.warning(f"Could not load adapter model, make sure to have PEFT >= {MIN_PEFT_VERSION} installed")
             else:
                 # We load the sharded checkpoint
-                load_result = load_sharded_checkpoint(
-                    model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
-                )
+                load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled())
                 if not is_sagemaker_mp_enabled():
                     self._issue_warnings_after_load(load_result)

@@ -2913,7 +2924,7 @@ class Trainer:
                 has_been_loaded = False
             else:
                 # We load the model state dict on the CPU to avoid an OOM error.
-                if
+                if os.path.isfile(best_safe_model_path):
                     state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                 else:
                     check_torch_load_is_safe()
@@ -4067,12 +4078,7 @@ class Trainer:
             model = model.module.module
         unwrapped_model = self.accelerator.unwrap_model(model)
         if isinstance(unwrapped_model, supported_classes):
-            unwrapped_model.save_pretrained(
-                output_dir,
-                state_dict=full_state_dict,
-                save_function=xm.save,
-                safe_serialization=self.args.save_safetensors,
-            )
+            unwrapped_model.save_pretrained(output_dir, state_dict=full_state_dict)
         else:
             logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
             xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))

@@ -4082,8 +4088,6 @@ class Trainer:
                 output_dir,
                 is_main_process=self.args.should_save,
                 state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
-                save_function=xm.save,
-                safe_serialization=self.args.save_safetensors,
             )
         else:
             logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")

@@ -4093,8 +4097,6 @@ class Trainer:
             model.save_pretrained(
                 output_dir,
                 is_main_process=self.args.should_save,
-                save_function=xm.save,
-                safe_serialization=self.args.save_safetensors,
                 state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
             )
         if self.processing_class is not None and self.args.should_save:
@@ -4115,20 +4117,15 @@ class Trainer:
 
             if isinstance(self.accelerator.unwrap_model(self.model, keep_torch_compile=False), supported_classes):
                 self.accelerator.unwrap_model(self.model, keep_torch_compile=False).save_pretrained(
-                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                    output_dir, state_dict=state_dict
                 )
             else:
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
-                if self.args.save_safetensors:
-                    safetensors.torch.save_file(
-                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
-                    )
-                else:
-                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+                safetensors.torch.save_file(
+                    state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
+                )
         else:
-            self.model.save_pretrained(
-                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
-            )
+            self.model.save_pretrained(output_dir, state_dict=state_dict)
 
         if self.processing_class is not None:
             self.processing_class.save_pretrained(output_dir)
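With `safe_serialization` no longer consulted, the non-`PreTrainedModel` branch above always writes a safetensors file. A small self-contained sketch of the same save/load round trip on a throwaway module, assuming only `torch` and `safetensors` are installed:

import os
import tempfile

import torch
import safetensors.torch

SAFE_WEIGHTS_NAME = "model.safetensors"  # same value as the transformers constant

model = torch.nn.Linear(4, 4)
# safetensors requires contiguous tensors; Linear weights already are, but be explicit.
state_dict = {k: v.contiguous() for k, v in model.state_dict().items()}

with tempfile.TemporaryDirectory() as output_dir:
    # Mirrors the new code path: always safetensors, tagged with the "pt" format.
    safetensors.torch.save_file(
        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
    )
    reloaded = safetensors.torch.load_file(os.path.join(output_dir, SAFE_WEIGHTS_NAME), device="cpu")
    model.load_state_dict(reloaded)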
@@ -5074,14 +5071,14 @@ class Trainer:
         self.is_tp_enabled = False
         if getattr(self.model, "tp_size", None) is not None and self.model.tp_size > 1:
             self.is_tp_enabled = True
-            if self.args.parallelism_config is
-                if is_accelerate_available("1.
-                    if self.args.parallelism_config is
+            if self.args.parallelism_config is None:
+                if is_accelerate_available("1.12.0"):
+                    if self.args.parallelism_config is None:
                         from accelerate import ParallelismConfig
 
                         args["parallelism_config"] = ParallelismConfig(tp_size=self.model.tp_size)
                 else:
-                    raise ValueError("Requires accelerate>1.
+                    raise ValueError("Requires accelerate>1.12.0 to use Tensor Parallelism.")
 
         if is_accelerate_available("1.2.0"):
             # it we don't have the correct version, we will rely on env var instead that were set in TrainingArguments
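The tensor-parallel branch now requires accelerate 1.12.0 before constructing a `ParallelismConfig`. A sketch of the same guard in isolation; `tp_size` is a placeholder value and the construction simply mirrors the added lines above:

from transformers.utils import is_accelerate_available

accelerator_args = {}
tp_size = 2  # placeholder; the Trainer takes this from self.model.tp_size

if is_accelerate_available("1.12.0"):
    from accelerate import ParallelismConfig

    # Same construction as in the hunk above.
    accelerator_args["parallelism_config"] = ParallelismConfig(tp_size=tp_size)
else:
    raise ValueError("Requires accelerate>1.12.0 to use Tensor Parallelism.")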
@@ -0,0 +1,126 @@
+import os
+import signal
+import threading
+from typing import Optional
+
+from .trainer_callback import TrainerCallback
+from .trainer_utils import PREFIX_CHECKPOINT_DIR
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class CheckpointManager:
+    def __init__(self, trainer, kill_wait: int = 3):
+        """
+        Initialize the CheckpointManager for Just-In-Time checkpoint handling.
+
+        Args:
+            trainer: The Trainer instance that will be used to save checkpoints when SIGTERM is received.
+            kill_wait (`int`, *optional*, defaults to 3): Grace period to distinguish between SIGTERM and SIGKILL.
+        """
+        self.trainer = trainer
+        self.is_checkpoint_requested = False
+        self._original_sigterm_handler = None
+        self.kill_wait = kill_wait
+
+    def setup_signal_handler(self):
+        self._original_sigterm_handler = signal.signal(signal.SIGTERM, self._sigterm_handler)
+        logger.info("JIT checkpoint signal handler registered for SIGTERM")
+
+    def _sigterm_handler(self, signum, frame):
+        if self.is_checkpoint_requested:
+            return
+
+        logger.info(f"SIGTERM received, will request JIT checkpoint after {self.kill_wait}s")
+        threading.Timer(self.kill_wait, self._enable_checkpoint).start()
+
+    def _enable_checkpoint(self):
+        logger.info("Kill wait period elapsed, requesting checkpoint")
+        self.is_checkpoint_requested = True
+
+    def execute_jit_checkpoint(self):
+        try:
+            # Set checkpoint flag to False to avoid multiple checkpoints getting triggered by other callbacks
+            self.is_checkpoint_requested = False
+
+            logger.info("Starting JIT checkpointing...")
+            current_step = self.trainer.state.global_step
+            logger.info(f"Saving JIT checkpoint at step {current_step}")
+
+            output_dir = self.trainer._get_output_dir(trial=None)
+            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{current_step}"
+            checkpoint_path = os.path.join(output_dir, checkpoint_folder)
+
+            # Create checkpoint directory
+            os.makedirs(checkpoint_path, exist_ok=True)
+
+            # Create a sentinel file to indicate checkpointing is in progress
+            sentinel_file = os.path.join(output_dir, checkpoint_folder, "checkpoint-is-incomplete.txt")
+            with open(sentinel_file, "w") as f:
+                f.write(f"Checkpoint started at step {current_step} and in progress...")
+            logger.info(f"Created checkpoint progress sentinel marker file: {sentinel_file}")
+
+            # Invoke the trainer's checkpoint method directly
+            self.trainer._save_checkpoint(self.trainer.model, trial=None)
+
+            # Remove sentinel file upon successful checkpointing
+            if os.path.exists(sentinel_file):
+                os.remove(sentinel_file)
+                logger.info("Sentinel marker file removed")
+
+            logger.info("Immediate JIT checkpoint completed successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to save JIT checkpoint: {e}")
+            raise
+
+
+class JITCheckpointCallback(TrainerCallback):
+    """
+    Callback for Just-In-Time checkpointing on SIGTERM signals.
+
+    When SIGTERM is received, the checkpoint manager sets `is_checkpoint_requested=True`.
+    The callbacks detect this flag and set `control.should_training_stop=True`, which signals
+    the Trainer's training loop to exit gracefully after saving the checkpoint.
+    """
+
+    def __init__(self):
+        self.trainer = None
+        self.jit_manager: Optional[CheckpointManager] = None
+
+    def set_trainer(self, trainer):
+        self.trainer = trainer
+        if trainer.args.enable_jit_checkpoint:
+            self.jit_manager = CheckpointManager(trainer=trainer)
+            self.jit_manager.setup_signal_handler()
+            logger.info("JIT checkpointing enabled")
+
+    def on_pre_optimizer_step(self, args, state, control, **kwargs):
+        if self.jit_manager and self.jit_manager.is_checkpoint_requested:
+            control.should_training_stop = True
+            self.jit_manager.execute_jit_checkpoint()
+
+    def on_step_begin(self, args, state, control, **kwargs):
+        if self.jit_manager and self.jit_manager.is_checkpoint_requested:
+            control.should_training_stop = True
+            self.jit_manager.execute_jit_checkpoint()
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if self.jit_manager and self.jit_manager.is_checkpoint_requested:
+            control.should_save = False
+            control.should_training_stop = True
+            self.jit_manager.execute_jit_checkpoint()
+
+    def on_epoch_end(self, args, state, control, **kwargs):
+        if self.jit_manager and self.jit_manager.is_checkpoint_requested:
+            control.should_save = False
+            control.should_training_stop = True
+            self.jit_manager.execute_jit_checkpoint()
+
+    def on_train_end(self, args, state, control, **kwargs):
+        # Restore original SIGTERM handler
+        if self.jit_manager and self.jit_manager._original_sigterm_handler is not None:
+            signal.signal(signal.SIGTERM, self.jit_manager._original_sigterm_handler)
+            logger.info("Restored original SIGTERM handler after training completion")
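For context, a sketch of how the new callback could be attached to a `Trainer`. The import path is hypothetical (the new file's name is not shown in this diff), `enable_jit_checkpoint` is assumed to be exposed as a `TrainingArguments` field since the callback reads it from `trainer.args`, and the tiny model/dataset are placeholders. Sending SIGTERM to the running process would then save a `checkpoint-<step>` directory and stop training gracefully:

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from transformers.trainer_jit_checkpoint import JITCheckpointCallback  # hypothetical module path

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")


def tokenize(batch):
    out = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=16)
    out["labels"] = out["input_ids"].copy()
    return out


train_dataset = Dataset.from_dict({"text": ["hello world"] * 64}).map(tokenize, batched=True)

args = TrainingArguments(
    output_dir="jit-demo",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    enable_jit_checkpoint=True,  # assumed TrainingArguments flag, read by set_trainer() above
)

callback = JITCheckpointCallback()
trainer = Trainer(model=model, args=args, train_dataset=train_dataset, callbacks=[callback])
callback.set_trainer(trainer)  # registers the SIGTERM handler when the flag is set

# `kill -TERM <pid>` from another shell now requests a checkpoint after the
# kill-wait grace period; the callbacks save it and stop the training loop.
trainer.train()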
transformers/trainer_utils.py CHANGED
@@ -924,7 +924,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     shard_files = list(set(index["weight_map"].values()))
 
     # If strict=True, error before loading any of the state dicts.
-    # TODO: Here, update the
+    # TODO: Here, update the weight map with the config.dynamic_weight_conversion
     loaded_keys = index["weight_map"].keys()
     model_keys = model.state_dict().keys()
     missing_keys = [key for key in model_keys if key not in loaded_keys]
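The strict check below the updated TODO compares the checkpoint's shard index against the model's own state dict. A standalone sketch of that comparison, assuming the sharded checkpoint uses the usual `model.safetensors.index.json` index name:

import json
import os

import torch


def check_sharded_keys(model: torch.nn.Module, folder: str):
    # Read the shard index written alongside sharded safetensors checkpoints.
    with open(os.path.join(folder, "model.safetensors.index.json")) as f:
        index = json.load(f)

    loaded_keys = index["weight_map"].keys()
    model_keys = model.state_dict().keys()

    # Keys the model expects but the checkpoint does not provide, and vice versa,
    # mirroring the strict=True check in load_sharded_checkpoint.
    missing_keys = [key for key in model_keys if key not in loaded_keys]
    unexpected_keys = [key for key in loaded_keys if key not in model_keys]
    return missing_keys, unexpected_keys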