transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/generation/utils.py
CHANGED
|
@@ -19,12 +19,12 @@ import inspect
|
|
|
19
19
|
import os
|
|
20
20
|
import warnings
|
|
21
21
|
from collections.abc import Callable
|
|
22
|
+
from contextlib import contextmanager
|
|
22
23
|
from dataclasses import dataclass
|
|
23
24
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
24
25
|
|
|
25
26
|
import torch
|
|
26
27
|
import torch.distributed as dist
|
|
27
|
-
from packaging import version
|
|
28
28
|
from torch import nn
|
|
29
29
|
|
|
30
30
|
from ..cache_utils import (
|
|
@@ -407,6 +407,9 @@ class GenerationMixin(ContinuousMixin):
|
|
|
407
407
|
**repo_loading_kwargs,
|
|
408
408
|
)
|
|
409
409
|
except OSError:
|
|
410
|
+
# `self` already has a generation config created from model config, but model config will
|
|
411
|
+
# not contain any generation-specific params. These are popped at config's `__init__`.
|
|
412
|
+
# Thus we have to load from `config.json` and create a generation config from it (for BART)
|
|
410
413
|
logger.info(
|
|
411
414
|
"Generation config file not found, using a generation config created from the model config."
|
|
412
415
|
)
|
|
@@ -418,6 +421,7 @@ class GenerationMixin(ContinuousMixin):
|
|
|
418
421
|
_from_model_config=True,
|
|
419
422
|
**repo_loading_kwargs,
|
|
420
423
|
)
|
|
424
|
+
|
|
421
425
|
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
|
|
422
426
|
if hasattr(self, "load_custom_generate") and trust_remote_code:
|
|
423
427
|
try:
|
|
@@ -593,6 +597,7 @@ class GenerationMixin(ContinuousMixin):
|
|
|
593
597
|
attention_mask: torch.LongTensor | None = None,
|
|
594
598
|
inputs_embeds: torch.FloatTensor | None = None,
|
|
595
599
|
cache_position: torch.LongTensor | None = None,
|
|
600
|
+
is_first_iteration: bool | None = False,
|
|
596
601
|
**kwargs,
|
|
597
602
|
):
|
|
598
603
|
"""
|
|
@@ -628,7 +633,7 @@ class GenerationMixin(ContinuousMixin):
|
|
|
628
633
|
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
|
629
634
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt.
|
|
630
635
|
if not self.config.is_encoder_decoder:
|
|
631
|
-
if inputs_embeds is not None and
|
|
636
|
+
if inputs_embeds is not None and is_first_iteration:
|
|
632
637
|
model_inputs[input_ids_key] = None
|
|
633
638
|
model_inputs["inputs_embeds"] = inputs_embeds
|
|
634
639
|
else:
|
|
@@ -708,6 +713,7 @@ class GenerationMixin(ContinuousMixin):
|
|
|
708
713
|
past_key_values=past_key_values,
|
|
709
714
|
position_ids=position_ids,
|
|
710
715
|
token_type_ids=token_type_ids,
|
|
716
|
+
is_first_iteration=is_first_iteration,
|
|
711
717
|
)
|
|
712
718
|
else:
|
|
713
719
|
attention_mask = causal_mask_creation_function(
|
|
@@ -1300,7 +1306,7 @@ class GenerationMixin(ContinuousMixin):
|
|
|
1300
1306
|
if generation_config.do_sample:
|
|
1301
1307
|
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
|
|
1302
1308
|
# better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
|
|
1303
|
-
if generation_config.num_beams > 1:
|
|
1309
|
+
if generation_config.num_beams is not None and generation_config.num_beams > 1:
|
|
1304
1310
|
if isinstance(generation_config._eos_token_tensor, list):
|
|
1305
1311
|
min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
|
|
1306
1312
|
elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
|
|
@@ -1722,8 +1728,8 @@ class GenerationMixin(ContinuousMixin):
|
|
|
1722
1728
|
)
|
|
1723
1729
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
|
1724
1730
|
|
|
1725
|
-
#
|
|
1726
|
-
#
|
|
1731
|
+
# If both `inputs_embeds` and `input_ids` are passed, we correct length with `inputs_tensor.shape`
|
|
1732
|
+
# We need to get max_length = inputs_embeds_len + max_new_tokens
|
|
1727
1733
|
elif (
|
|
1728
1734
|
model_input_name == "inputs_embeds"
|
|
1729
1735
|
and input_ids_length != inputs_tensor.shape[1]
|
|
@@ -1731,11 +1737,10 @@ class GenerationMixin(ContinuousMixin):
|
|
|
1731
1737
|
):
|
|
1732
1738
|
generation_config.max_length -= inputs_tensor.shape[1]
|
|
1733
1739
|
elif has_default_max_length: # by default let's always generate 20 new tokens
|
|
1734
|
-
|
|
1735
|
-
|
|
1736
|
-
|
|
1737
|
-
|
|
1738
|
-
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
|
|
1740
|
+
generation_config.max_length = generation_config.max_length + input_ids_length
|
|
1741
|
+
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
|
1742
|
+
if max_position_embeddings is not None:
|
|
1743
|
+
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
|
|
1739
1744
|
|
|
1740
1745
|
# same for min length
|
|
1741
1746
|
if generation_config.min_new_tokens is not None:
|
|
@@ -1760,7 +1765,6 @@ class GenerationMixin(ContinuousMixin):
|
|
|
1760
1765
|
def _prepare_generation_config(
|
|
1761
1766
|
self,
|
|
1762
1767
|
generation_config: GenerationConfig | None,
|
|
1763
|
-
use_model_defaults: bool | None = None,
|
|
1764
1768
|
**kwargs: Any,
|
|
1765
1769
|
) -> tuple[GenerationConfig, dict]:
|
|
1766
1770
|
"""
|
|
@@ -1768,93 +1772,57 @@ class GenerationMixin(ContinuousMixin):
|
|
|
1768
1772
|
function handles retrocompatibility with respect to configuration files.
|
|
1769
1773
|
"""
|
|
1770
1774
|
# parameterization priority:
|
|
1771
|
-
#
|
|
1775
|
+
# user-defined kwargs or `generation_config` > `self.generation_config` > global default values
|
|
1776
|
+
# TODO: (raushan) doesn't make sense to allow kwargs and `generation_config`. Should be mutually exclusive!
|
|
1772
1777
|
# TODO (joao): per-model generation config classes.
|
|
1773
1778
|
|
|
1774
|
-
using_model_generation_config = False
|
|
1775
1779
|
if generation_config is None:
|
|
1776
|
-
#
|
|
1777
|
-
|
|
1778
|
-
|
|
1779
|
-
|
|
1780
|
-
|
|
1781
|
-
|
|
1782
|
-
|
|
1783
|
-
|
|
1784
|
-
|
|
1785
|
-
and len(self.config._get_non_default_generation_parameters()) > 0 # 3)
|
|
1786
|
-
):
|
|
1787
|
-
new_generation_config = GenerationConfig.from_model_config(self.config)
|
|
1788
|
-
if new_generation_config != self.generation_config: # 4)
|
|
1789
|
-
raise ValueError(
|
|
1790
|
-
"You have modified the pretrained model configuration to control generation."
|
|
1791
|
-
" This strategy to control generation is not supported anymore. "
|
|
1792
|
-
" Please use and modify the model generation configuration (see"
|
|
1793
|
-
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
|
|
1794
|
-
)
|
|
1795
|
-
|
|
1796
|
-
generation_config = self.generation_config
|
|
1797
|
-
using_model_generation_config = True
|
|
1798
|
-
|
|
1799
|
-
# Related to #40039: prior to this PR, models with sliding window attention were forced to have
|
|
1800
|
-
# `cache_implementation="hybrid"` (the static sliding window cache). For these models, we now want to use
|
|
1801
|
-
# the dynamic sliding window cache by default, so we UNSET `cache_implementation` if it is a default value.
|
|
1802
|
-
# (if we're inside this branch, then it is because we're using default values from the Hub)
|
|
1803
|
-
if generation_config.cache_implementation == "hybrid":
|
|
1804
|
-
generation_config.cache_implementation = None
|
|
1780
|
+
# Users may modify `model.config` to control generation. This is a legacy behavior and is not supported anymore
|
|
1781
|
+
if len(self.config._get_generation_parameters()) > 0:
|
|
1782
|
+
raise ValueError(
|
|
1783
|
+
"You have modified the pretrained model configuration to control generation "
|
|
1784
|
+
f"We detected the following values set - {self.config._get_generation_parameters()}. "
|
|
1785
|
+
"This strategy to control generation is not supported anymore. Please use and modify `model.generation_config` "
|
|
1786
|
+
"(see https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
|
|
1787
|
+
)
|
|
1788
|
+
generation_config = GenerationConfig()
|
|
1805
1789
|
|
|
1806
1790
|
# `torch.export.export` usually raises an exception if it is called
|
|
1807
1791
|
# with ``strict=True``. deepcopy can only be processed if ``strict=False``.
|
|
1808
1792
|
generation_config = copy.deepcopy(generation_config)
|
|
1809
1793
|
|
|
1810
|
-
|
|
1811
|
-
|
|
1812
|
-
|
|
1813
|
-
|
|
1814
|
-
|
|
1815
|
-
|
|
1816
|
-
|
|
1817
|
-
|
|
1818
|
-
|
|
1819
|
-
|
|
1820
|
-
|
|
1821
|
-
|
|
1822
|
-
|
|
1823
|
-
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
1827
|
-
|
|
1828
|
-
|
|
1829
|
-
|
|
1830
|
-
|
|
1831
|
-
|
|
1832
|
-
and model_gen_config_value != global_default_value
|
|
1833
|
-
):
|
|
1834
|
-
modified_values[key] = model_gen_config_value
|
|
1835
|
-
setattr(generation_config, key, model_gen_config_value)
|
|
1836
|
-
# edge case: we may set `temperature=0.0` and `do_sample=False`, but the model defaults to
|
|
1837
|
-
# `do_sample=True`
|
|
1838
|
-
if generation_config.temperature == 0.0:
|
|
1839
|
-
generation_config.do_sample = False
|
|
1840
|
-
if use_model_defaults is None and len(modified_values) > 0:
|
|
1841
|
-
logger.warning_once(
|
|
1842
|
-
f"`generation_config` default values have been modified to match model-specific defaults: "
|
|
1843
|
-
f"{modified_values}. If this is not desired, please set these values explicitly."
|
|
1844
|
-
)
|
|
1845
|
-
else:
|
|
1846
|
-
if generation_config.bos_token_id is None:
|
|
1847
|
-
generation_config.bos_token_id = self.generation_config.bos_token_id
|
|
1848
|
-
if generation_config.eos_token_id is None:
|
|
1849
|
-
generation_config.eos_token_id = self.generation_config.eos_token_id
|
|
1850
|
-
if generation_config.pad_token_id is None:
|
|
1851
|
-
generation_config.pad_token_id = self.generation_config.pad_token_id
|
|
1852
|
-
if generation_config.decoder_start_token_id is None:
|
|
1853
|
-
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
|
|
1854
|
-
|
|
1855
|
-
# Finally, apply any passed kwargs
|
|
1794
|
+
# First set values from the loaded `self.generation_config`, then set default values (BC)
|
|
1795
|
+
# Do not update any values that aren't `None`, i.e. if set by users explicitly and passed
|
|
1796
|
+
# to `generate()`. Thus the `defaults_only=True` is used
|
|
1797
|
+
global_defaults = self.generation_config._get_default_generation_params()
|
|
1798
|
+
generation_config.update(**self.generation_config.to_dict(), defaults_only=True)
|
|
1799
|
+
generation_config.update(**global_defaults, defaults_only=True)
|
|
1800
|
+
|
|
1801
|
+
# Due to some values being boolean and not `None`, we need additional logic to overwrite
|
|
1802
|
+
# them explicitly (`defaults_only=False`) on the condition that it's only a previous
|
|
1803
|
+
# default value
|
|
1804
|
+
default_generation_config = GenerationConfig()
|
|
1805
|
+
generation_config.update(
|
|
1806
|
+
**{
|
|
1807
|
+
k: v
|
|
1808
|
+
for k, v in self.generation_config.to_dict().items()
|
|
1809
|
+
if isinstance(v, bool)
|
|
1810
|
+
and hasattr(default_generation_config, k)
|
|
1811
|
+
and getattr(generation_config, k, None) == getattr(default_generation_config, k)
|
|
1812
|
+
}
|
|
1813
|
+
)
|
|
1814
|
+
|
|
1815
|
+
# Finally, if there are any kwargs, update config with it -> highest priority at the end
|
|
1856
1816
|
model_kwargs = generation_config.update(**kwargs)
|
|
1857
|
-
|
|
1817
|
+
|
|
1818
|
+
# Related to #40039: prior to this PR, models with sliding window attention were forced to have
|
|
1819
|
+
# `cache_implementation="hybrid"` (the static sliding window cache). For these models, we now want to use
|
|
1820
|
+
# the dynamic sliding window cache by default, so we UNSET `cache_implementation` if it is a default value.
|
|
1821
|
+
# (if we're inside this branch, then it is because we're using default values from the Hub)
|
|
1822
|
+
if generation_config.cache_implementation == "hybrid":
|
|
1823
|
+
generation_config.cache_implementation = None
|
|
1824
|
+
|
|
1825
|
+
# Finally keep output_xxx args in `model_kwargs` so it can be passed to `forward`
|
|
1858
1826
|
output_attentions = generation_config.output_attentions
|
|
1859
1827
|
output_hidden_states = generation_config.output_hidden_states
|
|
1860
1828
|
model_kwargs.update({"output_attentions": output_attentions} if output_attentions else {})
|
|
@@ -2211,8 +2179,10 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2211
2179
|
"will be skipped."
|
|
2212
2180
|
)
|
|
2213
2181
|
|
|
2214
|
-
|
|
2182
|
+
if can_compile:
|
|
2183
|
+
# Finally: if we can compile, disable tokenizers parallelism
|
|
2215
2184
|
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
|
2185
|
+
|
|
2216
2186
|
# If we use FA2 and a static cache, we cannot compile with fullgraph
|
|
2217
2187
|
if self.config._attn_implementation == "flash_attention_2":
|
|
2218
2188
|
# only raise warning if the user passed an explicit compile-config
|
|
@@ -2225,6 +2195,22 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2225
2195
|
|
|
2226
2196
|
return can_compile
|
|
2227
2197
|
|
|
2198
|
+
@contextmanager
|
|
2199
|
+
def _optimize_model_for_decode(self):
|
|
2200
|
+
original_experts_implementation = self.config._experts_implementation
|
|
2201
|
+
if original_experts_implementation == "grouped_mm":
|
|
2202
|
+
logger.info_once(
|
|
2203
|
+
"We will be switching to 'batched_mm' for the decoding stage as it is much more performant than 'grouped_mm' on smaller inputs. "
|
|
2204
|
+
"If you experience any issues with this, please open an issue on the Hugging Face Transformers GitHub repository.",
|
|
2205
|
+
)
|
|
2206
|
+
self.set_experts_implementation("batched_mm")
|
|
2207
|
+
|
|
2208
|
+
try:
|
|
2209
|
+
yield
|
|
2210
|
+
finally:
|
|
2211
|
+
if original_experts_implementation == "grouped_mm":
|
|
2212
|
+
self.set_experts_implementation(original_experts_implementation)
|
|
2213
|
+
|
|
2228
2214
|
def _get_deprecated_gen_repo(
|
|
2229
2215
|
self,
|
|
2230
2216
|
generation_mode: GenerationMode,
|
|
@@ -2294,7 +2280,6 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2294
2280
|
streamer: Optional["BaseStreamer"] = None,
|
|
2295
2281
|
negative_prompt_ids: torch.Tensor | None = None,
|
|
2296
2282
|
negative_prompt_attention_mask: torch.Tensor | None = None,
|
|
2297
|
-
use_model_defaults: bool | None = None,
|
|
2298
2283
|
custom_generate: str | Callable | None = None,
|
|
2299
2284
|
**kwargs,
|
|
2300
2285
|
) -> GenerateOutput | torch.LongTensor:
|
|
@@ -2360,11 +2345,6 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2360
2345
|
size. This is an experimental feature, subject to breaking API changes in future versions.
|
|
2361
2346
|
negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
2362
2347
|
Attention_mask for `negative_prompt_ids`.
|
|
2363
|
-
use_model_defaults (`bool`, *optional*):
|
|
2364
|
-
When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
|
|
2365
|
-
generation configuration (`model.generation_config`), as opposed to the global defaults
|
|
2366
|
-
(`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
|
|
2367
|
-
`True`.
|
|
2368
2348
|
custom_generate (`str` or `Callable`, *optional*):
|
|
2369
2349
|
One of the following:
|
|
2370
2350
|
- `str` (Hugging Face Hub repository name): runs the custom `generate` function defined at
|
|
@@ -2474,7 +2454,7 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2474
2454
|
# switch to CB
|
|
2475
2455
|
outputs = self.generate_batch(
|
|
2476
2456
|
inputs=inputs,
|
|
2477
|
-
generation_config=self._prepare_generation_config(generation_config,
|
|
2457
|
+
generation_config=self._prepare_generation_config(generation_config, **kwargs)[0],
|
|
2478
2458
|
**kwargs,
|
|
2479
2459
|
)
|
|
2480
2460
|
sequences = [
|
|
@@ -2495,9 +2475,15 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2495
2475
|
streamer,
|
|
2496
2476
|
)
|
|
2497
2477
|
|
|
2498
|
-
|
|
2499
|
-
|
|
2478
|
+
# Check length values before updating the config with defaults. We'll use it later to define the final min/max length (# 6)
|
|
2479
|
+
has_default_max_length = kwargs.get("max_length") is None and (
|
|
2480
|
+
generation_config is None or generation_config.max_length is None
|
|
2500
2481
|
)
|
|
2482
|
+
has_default_min_length = kwargs.get("min_length") is None and (
|
|
2483
|
+
generation_config is None or generation_config.min_length is None
|
|
2484
|
+
)
|
|
2485
|
+
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
|
2486
|
+
|
|
2501
2487
|
generation_mode = generation_config.get_generation_mode(assistant_model)
|
|
2502
2488
|
if isinstance(custom_generate, Callable):
|
|
2503
2489
|
decoding_method = custom_generate
|
|
@@ -2523,7 +2509,6 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2523
2509
|
assistant_model=assistant_model,
|
|
2524
2510
|
negative_prompt_ids=negative_prompt_ids,
|
|
2525
2511
|
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
|
2526
|
-
use_model_defaults=use_model_defaults,
|
|
2527
2512
|
custom_generate=deprecated_mode_repo,
|
|
2528
2513
|
trust_remote_code=trust_remote_code,
|
|
2529
2514
|
**generation_mode_kwargs,
|
|
@@ -2614,8 +2599,6 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2614
2599
|
|
|
2615
2600
|
# 6. Prepare `max_length` depending on other stopping criteria.
|
|
2616
2601
|
input_ids_length = input_ids.shape[1]
|
|
2617
|
-
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
|
2618
|
-
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
|
|
2619
2602
|
generation_config = self._prepare_generated_length(
|
|
2620
2603
|
generation_config=generation_config,
|
|
2621
2604
|
has_default_max_length=has_default_max_length,
|
|
@@ -2873,13 +2856,20 @@ class GenerationMixin(ContinuousMixin):
|
|
|
2873
2856
|
else self.__call__
|
|
2874
2857
|
)
|
|
2875
2858
|
|
|
2876
|
-
|
|
2877
|
-
|
|
2859
|
+
# Assisted generation completes the prefill stage in candidate generator so that
|
|
2860
|
+
# we don't have several `prefill` calls in one generation loop. Skip `_prefill` for assistants
|
|
2861
|
+
if not generation_config.is_assistant:
|
|
2862
|
+
outputs = self._prefill(input_ids, generation_config, model_kwargs)
|
|
2863
|
+
prefill_consumed = False
|
|
2864
|
+
else:
|
|
2865
|
+
model_kwargs = self._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
|
|
2866
|
+
prefill_consumed = True
|
|
2878
2867
|
|
|
2879
2868
|
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
|
2880
2869
|
if prefill_consumed:
|
|
2881
2870
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
|
2882
|
-
|
|
2871
|
+
with self._optimize_model_for_decode():
|
|
2872
|
+
outputs = model_forward(**model_inputs, return_dict=True)
|
|
2883
2873
|
prefill_consumed = True
|
|
2884
2874
|
model_kwargs = self._update_model_kwargs_for_generation(
|
|
2885
2875
|
outputs,
|
|
@@ -3351,9 +3341,15 @@ class GenerationMixin(ContinuousMixin):
|
|
|
3351
3341
|
)
|
|
3352
3342
|
beam_indices = running_beam_indices.detach().clone()
|
|
3353
3343
|
|
|
3354
|
-
prefill_consumed = False
|
|
3355
3344
|
flat_running_sequences = input_ids
|
|
3356
|
-
|
|
3345
|
+
# Assisted generation completes the prefill stage in candidate generator so that
|
|
3346
|
+
# we don't have several `prefill` calls in one generation loop. Skip `_prefill` for assistants
|
|
3347
|
+
if not generation_config.is_assistant:
|
|
3348
|
+
model_outputs = self._prefill(input_ids, generation_config, model_kwargs)
|
|
3349
|
+
prefill_consumed = False
|
|
3350
|
+
else:
|
|
3351
|
+
model_kwargs = self._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
|
|
3352
|
+
prefill_consumed = True
|
|
3357
3353
|
|
|
3358
3354
|
# 4. run the generation loop
|
|
3359
3355
|
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
|
@@ -3659,7 +3655,7 @@ class GenerationMixin(ContinuousMixin):
|
|
|
3659
3655
|
cur_len = input_ids.shape[1]
|
|
3660
3656
|
|
|
3661
3657
|
# 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
|
|
3662
|
-
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
|
|
3658
|
+
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids, is_first_iteration)
|
|
3663
3659
|
candidate_input_ids = candidate_input_ids.to(self.device)
|
|
3664
3660
|
if candidate_logits is not None:
|
|
3665
3661
|
candidate_logits = candidate_logits.to(self.device)
|
|
@@ -3686,7 +3682,9 @@ class GenerationMixin(ContinuousMixin):
|
|
|
3686
3682
|
dim=0,
|
|
3687
3683
|
)
|
|
3688
3684
|
|
|
3689
|
-
model_inputs = self.prepare_inputs_for_generation(
|
|
3685
|
+
model_inputs = self.prepare_inputs_for_generation(
|
|
3686
|
+
candidate_input_ids, is_first_iteration=is_first_iteration, **candidate_kwargs
|
|
3687
|
+
)
|
|
3690
3688
|
if "logits_to_keep" in model_inputs:
|
|
3691
3689
|
model_inputs["logits_to_keep"] = candidate_length + 1
|
|
3692
3690
|
|
|
@@ -3849,7 +3847,7 @@ class GenerationMixin(ContinuousMixin):
|
|
|
3849
3847
|
def _prefill(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, model_kwargs):
|
|
3850
3848
|
if generation_config.prefill_chunk_size is None:
|
|
3851
3849
|
model_kwargs = self._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
|
|
3852
|
-
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
|
3850
|
+
model_inputs = self.prepare_inputs_for_generation(input_ids, is_first_iteration=True, **model_kwargs)
|
|
3853
3851
|
return self(**model_inputs, return_dict=True)
|
|
3854
3852
|
else: # Chunked prefill
|
|
3855
3853
|
# Even if we are not compiling the forward, flex is always compiled when used. With chunked prefill, we may
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
import collections
|
|
17
17
|
from dataclasses import dataclass
|
|
18
18
|
from functools import lru_cache
|
|
19
|
-
from typing import Any
|
|
19
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
20
20
|
|
|
21
21
|
import numpy as np
|
|
22
22
|
import torch
|
|
@@ -24,12 +24,15 @@ from torch import nn
|
|
|
24
24
|
from torch.nn import BCELoss
|
|
25
25
|
|
|
26
26
|
from .. import initialization as init
|
|
27
|
+
from ..configuration_utils import PreTrainedConfig
|
|
27
28
|
from ..modeling_utils import PreTrainedModel
|
|
28
29
|
from ..utils import ModelOutput, logging
|
|
29
|
-
from .configuration_utils import PreTrainedConfig, WatermarkingConfig
|
|
30
30
|
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
|
|
31
31
|
|
|
32
32
|
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from .configuration_utils import WatermarkingConfig
|
|
35
|
+
|
|
33
36
|
logger = logging.get_logger(__name__)
|
|
34
37
|
|
|
35
38
|
|
|
@@ -120,13 +123,13 @@ class WatermarkDetector:
|
|
|
120
123
|
|
|
121
124
|
def __init__(
|
|
122
125
|
self,
|
|
123
|
-
model_config: PreTrainedConfig,
|
|
126
|
+
model_config: "PreTrainedConfig",
|
|
124
127
|
device: str,
|
|
125
|
-
watermarking_config: WatermarkingConfig | dict,
|
|
128
|
+
watermarking_config: Optional[Union["WatermarkingConfig", dict]],
|
|
126
129
|
ignore_repeated_ngrams: bool = False,
|
|
127
130
|
max_cache_size: int = 128,
|
|
128
131
|
):
|
|
129
|
-
if isinstance(watermarking_config, WatermarkingConfig):
|
|
132
|
+
if not isinstance(watermarking_config, dict):
|
|
130
133
|
watermarking_config = watermarking_config.to_dict()
|
|
131
134
|
|
|
132
135
|
self.bos_token_id = (
|
|
@@ -71,8 +71,8 @@ class ImageProcessingMixin(PushToHubMixin):
|
|
|
71
71
|
# This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use
|
|
72
72
|
# `XXXImageProcessor`, this attribute and its value are misleading.
|
|
73
73
|
kwargs.pop("feature_extractor_type", None)
|
|
74
|
-
# Pop "processor_class"
|
|
75
|
-
|
|
74
|
+
# Pop "processor_class", should not be saved with image processing config anymore
|
|
75
|
+
kwargs.pop("processor_class", None)
|
|
76
76
|
# Additional attributes without default values
|
|
77
77
|
for key, value in kwargs.items():
|
|
78
78
|
try:
|
|
@@ -81,10 +81,6 @@ class ImageProcessingMixin(PushToHubMixin):
|
|
|
81
81
|
logger.error(f"Can't set {key} with value {value} for {self}")
|
|
82
82
|
raise err
|
|
83
83
|
|
|
84
|
-
def _set_processor_class(self, processor_class: str):
|
|
85
|
-
"""Sets processor class as an attribute."""
|
|
86
|
-
self._processor_class = processor_class
|
|
87
|
-
|
|
88
84
|
@classmethod
|
|
89
85
|
def from_pretrained(
|
|
90
86
|
cls: type[ImageProcessorType],
|
|
@@ -428,12 +424,6 @@ class ImageProcessingMixin(PushToHubMixin):
|
|
|
428
424
|
if isinstance(value, np.ndarray):
|
|
429
425
|
dictionary[key] = value.tolist()
|
|
430
426
|
|
|
431
|
-
# make sure private name "_processor_class" is correctly
|
|
432
|
-
# saved as "processor_class"
|
|
433
|
-
_processor_class = dictionary.pop("_processor_class", None)
|
|
434
|
-
if _processor_class is not None:
|
|
435
|
-
dictionary["processor_class"] = _processor_class
|
|
436
|
-
|
|
437
427
|
return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
|
|
438
428
|
|
|
439
429
|
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
|
@@ -932,11 +932,22 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
|
|
932
932
|
if do_pad:
|
|
933
933
|
processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)
|
|
934
934
|
|
|
935
|
-
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
|
936
935
|
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
|
|
937
936
|
|
|
938
937
|
def to_dict(self):
|
|
939
938
|
encoder_dict = super().to_dict()
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
939
|
+
|
|
940
|
+
# Filter out None values that are class defaults, but preserve explicitly set None values
|
|
941
|
+
filtered_dict = {}
|
|
942
|
+
for key, value in encoder_dict.items():
|
|
943
|
+
if value is None:
|
|
944
|
+
class_default = getattr(type(self), key, "NOT_FOUND")
|
|
945
|
+
# Keep None if user explicitly set it (class default is non-None)
|
|
946
|
+
if class_default != "NOT_FOUND" and class_default is not None:
|
|
947
|
+
filtered_dict[key] = value
|
|
948
|
+
else:
|
|
949
|
+
filtered_dict[key] = value
|
|
950
|
+
|
|
951
|
+
filtered_dict.pop("_valid_processor_keys", None)
|
|
952
|
+
filtered_dict.pop("_valid_kwargs_names", None)
|
|
953
|
+
return filtered_dict
|
transformers/initialization.py
CHANGED
|
@@ -206,3 +206,40 @@ def guard_torch_init_functions():
|
|
|
206
206
|
for module, functions in originals.items():
|
|
207
207
|
for func_name, func in functions.items():
|
|
208
208
|
setattr(module, func_name, func)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@contextmanager
|
|
212
|
+
def no_init_weights():
|
|
213
|
+
"""
|
|
214
|
+
Disable weight initialization both at the torch-level, and at the transformers-level (`init_weights`).
|
|
215
|
+
This is used to speed-up initializing an empty model with deepspeed, as we do not initialize the model on meta device
|
|
216
|
+
with deepspeed, but we still don't need to run expensive weight initializations as we are loading params afterwards.
|
|
217
|
+
"""
|
|
218
|
+
from .modeling_utils import PreTrainedModel
|
|
219
|
+
|
|
220
|
+
def empty_func(*args, **kwargs):
|
|
221
|
+
pass
|
|
222
|
+
|
|
223
|
+
originals = defaultdict(dict)
|
|
224
|
+
try:
|
|
225
|
+
# Replace all torch funcs by empty ones
|
|
226
|
+
for module_name in TORCH_MODULES_TO_PATCH:
|
|
227
|
+
if module_name in sys.modules:
|
|
228
|
+
module = sys.modules[module_name]
|
|
229
|
+
for func_name in TORCH_INIT_FUNCTIONS.keys():
|
|
230
|
+
if hasattr(module, func_name):
|
|
231
|
+
originals[module][func_name] = getattr(module, func_name)
|
|
232
|
+
setattr(module, func_name, empty_func)
|
|
233
|
+
|
|
234
|
+
# Also patch our own `init_weights`
|
|
235
|
+
original_init_weights = PreTrainedModel.init_weights
|
|
236
|
+
PreTrainedModel.init_weights = empty_func
|
|
237
|
+
|
|
238
|
+
yield
|
|
239
|
+
finally:
|
|
240
|
+
# Set back the original torch functions on all modules
|
|
241
|
+
for module, functions in originals.items():
|
|
242
|
+
for func_name, func in functions.items():
|
|
243
|
+
setattr(module, func_name, func)
|
|
244
|
+
# Set back `init_weights`
|
|
245
|
+
PreTrainedModel.init_weights = original_init_weights
|
|
@@ -69,6 +69,7 @@ _import_structure = {
|
|
|
69
69
|
"hqq": ["prepare_for_hqq_linear"],
|
|
70
70
|
"hub_kernels": [
|
|
71
71
|
"LayerRepository",
|
|
72
|
+
"lazy_load_kernel",
|
|
72
73
|
"register_kernel_mapping",
|
|
73
74
|
"replace_kernel_forward_from_hub",
|
|
74
75
|
"use_kernel_forward_from_hub",
|
|
@@ -116,6 +117,11 @@ _import_structure = {
|
|
|
116
117
|
"run_hp_search_ray",
|
|
117
118
|
"run_hp_search_wandb",
|
|
118
119
|
],
|
|
120
|
+
"moe": [
|
|
121
|
+
"batched_mm_experts_forward",
|
|
122
|
+
"grouped_mm_experts_forward",
|
|
123
|
+
"use_experts_implementation",
|
|
124
|
+
],
|
|
119
125
|
"mxfp4": [
|
|
120
126
|
"Mxfp4GptOssExperts",
|
|
121
127
|
"convert_moe_packed_tensors",
|
|
@@ -211,6 +217,7 @@ if TYPE_CHECKING:
|
|
|
211
217
|
from .hqq import prepare_for_hqq_linear
|
|
212
218
|
from .hub_kernels import (
|
|
213
219
|
LayerRepository,
|
|
220
|
+
lazy_load_kernel,
|
|
214
221
|
register_kernel_mapping,
|
|
215
222
|
replace_kernel_forward_from_hub,
|
|
216
223
|
use_kernel_forward_from_hub,
|
|
@@ -258,6 +265,11 @@ if TYPE_CHECKING:
|
|
|
258
265
|
run_hp_search_ray,
|
|
259
266
|
run_hp_search_wandb,
|
|
260
267
|
)
|
|
268
|
+
from .moe import (
|
|
269
|
+
batched_mm_experts_forward,
|
|
270
|
+
grouped_mm_experts_forward,
|
|
271
|
+
use_experts_implementation,
|
|
272
|
+
)
|
|
261
273
|
from .mxfp4 import (
|
|
262
274
|
Mxfp4GptOssExperts,
|
|
263
275
|
dequantize,
|