transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/generation/configuration_utils.py

@@ -20,12 +20,11 @@ import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass, is_dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 from huggingface_hub import create_repo
 
 from .. import __version__
-from ..configuration_utils import PreTrainedConfig
 from ..utils import (
     GENERATION_CONFIG_NAME,
     ExplicitEnum,
@@ -38,6 +37,7 @@ from ..utils import (
 
 
 if TYPE_CHECKING:
+    from ..configuration_utils import PreTrainedConfig
     from ..modeling_utils import PreTrainedModel
 
 
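These two hunks move the `PreTrainedConfig` import out of module scope and under `if TYPE_CHECKING:`, the standard Python idiom for imports that are only needed as type annotations, so they cost nothing at runtime and cannot create an import cycle. A minimal, self-contained sketch of the idiom, using `Decimal` as a stand-in for `PreTrainedConfig`:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright); skipped at runtime,
    # so the module imports cleanly even if this import were circular or heavy.
    from decimal import Decimal  # stand-in for PreTrainedConfig


def scale(value: "Decimal") -> "Decimal":
    # Quoted annotations are never evaluated at runtime, so there is no NameError
    # even though Decimal was only imported under TYPE_CHECKING.
    return value * 2


print(scale.__annotations__)  # {'value': 'Decimal', 'return': 'Decimal'}
```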
@@ -104,17 +104,18 @@ class GenerationConfig(PushToHubMixin):
     Arg:
         > Parameters that control the length of the output
 
-        max_length (`int`, *optional…
-            …
-            `…
+        max_length (`int`, *optional*):
+            `max_new_tokens` is recommended for controlling how many tokens the model generates.
+            `max_length` remains for backward compatibility.
+
         max_new_tokens (`int`, *optional*):
             The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
-        min_length (`int`, *optional…
+        min_length (`int`, *optional*):
             The minimum length of the sequence to be generated. Corresponds to the length of the input prompt +
             `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
         min_new_tokens (`int`, *optional*):
             The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
-        early_stopping (`bool` or `str`, *optional…
+        early_stopping (`bool` or `str`, *optional*):
             Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
             `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
             heuristic is applied and the generation stops when is it very unlikely to find better candidates;
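As a quick reference for the length controls documented in this hunk, here is a hedged usage sketch; the tiny checkpoint name is only an example, any causal LM checkpoint works the same way:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Prefer max_new_tokens: it bounds only the newly generated tokens, independent
# of the prompt length; max_length is kept for backward compatibility.
generation_config = GenerationConfig(max_new_tokens=20, min_new_tokens=5)
output_ids = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```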
@@ -128,17 +129,17 @@ class GenerationConfig(PushToHubMixin):
 
         > Parameters that control the generation strategy used
 
-        do_sample (`bool`,…
+        do_sample (`bool`, defaults to `False`):
             Whether or not to use sampling ; use greedy decoding otherwise.
-        num_beams (`int`, *optional…
+        num_beams (`int`, *optional*):
             Number of beams for beam search. 1 means no beam search.
 
         > Parameters that control the cache
 
-        use_cache (`bool`,…
+        use_cache (`bool`, defaults to `True`):
             Whether or not the model should use the past last key/values attentions (if applicable to the model) to
             speed up decoding.
-        cache_implementation (`str`, *optional…
+        cache_implementation (`str`, *optional*):
             Name of the cache class that will be instantiated in `generate`, for faster decoding. Possible values are:
 
             - `"dynamic"`: [`DynamicCache`]
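A short sketch of how these strategy and cache switches are typically combined; the specific values are illustrative, not recommendations:

```python
from transformers import GenerationConfig

# Greedy decoding (the default): do_sample=False, no beams.
greedy = GenerationConfig(do_sample=False)

# Ancestral sampling instead of greedy decoding.
sampled = GenerationConfig(do_sample=True, temperature=0.7)

# Beam search with 4 beams; early_stopping=True stops once 4 candidates finish.
beam = GenerationConfig(num_beams=4, early_stopping=True)

# Keep the KV cache on (the default) and pick an explicit cache class by name.
cached = GenerationConfig(use_cache=True, cache_implementation="dynamic")

for name, cfg in [("greedy", greedy), ("sampled", sampled), ("beam", beam), ("cached", cached)]:
    print(name, cfg)
```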
@@ -154,11 +155,11 @@ class GenerationConfig(PushToHubMixin):
 
         > Parameters for manipulation of the model output logits
 
-        temperature (`float`, *optional…
+        temperature (`float`, *optional*):
             The value used to module the next token probabilities. This value is set in a model's `generation_config.json` file. If it isn't set, the default value is 1.0
-        top_k (`int`, *optional…
+        top_k (`int`, *optional*):
             The number of highest probability vocabulary tokens to keep for top-k-filtering. This value is set in a model's `generation_config.json` file. If it isn't set, the default value is 50.
-        top_p (`float`, *optional…
+        top_p (`float`, *optional*):
             If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to
             `top_p` or higher are kept for generation. This value is set in a model's `generation_config.json` file. If it isn't set, the default value is 1.0
         min_p (`float`, *optional*):
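To make the three sampling knobs above concrete, here is a toy, pure-Python illustration of how temperature sharpens the distribution and how `top_k`/`top_p` then prune it; the logits are made up:

```python
import math

logits = {"cat": 4.0, "dog": 3.5, "car": 2.0, "the": 1.0}

def softmax(scores, temperature=1.0):
    exps = {t: math.exp(s / temperature) for t, s in scores.items()}
    total = sum(exps.values())
    return {t: v / total for t, v in exps.items()}

# Lower temperature sharpens the distribution before any truncation happens.
probs = softmax(logits, temperature=0.7)

# top_k=2: keep only the two highest-probability tokens.
top_k = dict(sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:2])

# top_p=0.9: keep the smallest set of top tokens whose cumulative probability reaches 0.9.
kept, cumulative = {}, 0.0
for token, p in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
    kept[token] = p
    cumulative += p
    if cumulative >= 0.9:
        break

print("softmax:", probs)
print("top_k=2:", top_k)
print("top_p=0.9:", kept)
```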
@@ -171,41 +172,41 @@ class GenerationConfig(PushToHubMixin):
             is kept whose *renormalized* entropy is less than or equal to `top_h` times the entropy of the full distribution.
             Smaller values (e.g., 0.2–0.5) lead to more focused, deterministic outputs, while values closer to 1.0 allow more
             randomness and diversity. Typical values are in the 0.3–0.6 range.
-        typical_p (`float`, *optional*, defaults to 1.0):
+        typical_p (`float`, *optional*):
             Local typicality measures how similar the conditional probability of predicting a target token next is to
             the expected conditional probability of predicting a random token next, given the partial text already
             generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that
             add up to `typical_p` or higher are kept for generation. See [this
             paper](https://huggingface.co/papers/2202.00666) for more details.
-        epsilon_cutoff (`float`, *optional*, defaults to 0.0):
+        epsilon_cutoff (`float`, *optional*):
             If set to float strictly between 0 and 1, only tokens with a conditional probability greater than
             `epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the
             size of the model. See [Truncation Sampling as Language Model
             Desmoothing](https://huggingface.co/papers/2210.15191) for more details.
-        eta_cutoff (`float`, *optional*, defaults to 0.0):
+        eta_cutoff (`float`, *optional*):
             Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between
             0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) *
             exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token
             probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3,
             depending on the size of the model. See [Truncation Sampling as Language Model
             Desmoothing](https://huggingface.co/papers/2210.15191) for more details.
-        repetition_penalty (`float`, *optional*, defaults to 1.0):
+        repetition_penalty (`float`, *optional*):
             The parameter for repetition penalty. 1.0 means no penalty. See [this
             paper](https://huggingface.co/papers/1909.05858) for more details.
-        encoder_repetition_penalty (`float`, *optional*, defaults to 1.0):
+        encoder_repetition_penalty (`float`, *optional*):
             The parameter for encoder_repetition_penalty. An exponential penalty on sequences that are not in the
             original input. 1.0 means no penalty.
-        length_penalty (`float`, *optional*, defaults to 1.0):
+        length_penalty (`float`, *optional*):
             Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
             the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
             likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
             `length_penalty` < 0.0 encourages shorter sequences.
-        no_repeat_ngram_size (`int`, *optional*, defaults to 0):
+        no_repeat_ngram_size (`int`, *optional*):
             If set to int > 0, all ngrams of that size can only occur once.
         bad_words_ids (`list[list[int]]`, *optional*):
             List of list of token ids that are not allowed to be generated. Check
             [`~generation.NoBadWordsLogitsProcessor`] for further documentation and examples.
-        renormalize_logits (`bool`, *optional*, defaults to `False`):
+        renormalize_logits (`bool`, defaults to `False`):
             Whether to renormalize the logits after applying all the logits processors (including the custom
             ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
             are normalized but some logit processors break the normalization.
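These sampling knobs are now documented as unset by default. A minimal sketch of that behaviour, assuming the rc2 `__init__` changes shown later in this diff: knobs left unset are stored as `None`, and the quoted fallbacks (1.0, 50, 1.0) only apply when neither the user nor the model's `generation_config.json` provides a value.

```python
# Sketch only: the actual fallback to 1.0 / 50 / 1.0 happens inside `generate`,
# not in this snippet.
from transformers import GenerationConfig

explicit = GenerationConfig(do_sample=True, temperature=0.7, top_k=40, top_p=0.9)
implicit = GenerationConfig(do_sample=True)  # temperature / top_k / top_p left unset

print(explicit.temperature, explicit.top_k, explicit.top_p)  # 0.7 40 0.9
print(implicit.temperature, implicit.top_k, implicit.top_p)  # None None None under rc2
```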
@@ -216,7 +217,7 @@ class GenerationConfig(PushToHubMixin):
         forced_eos_token_id (`int` or list[int]`, *optional*, defaults to `model.config.forced_eos_token_id`):
             The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
             list to set multiple *end-of-sequence* tokens.
-        remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
+        remove_invalid_values (`bool`, defaults to `model.config.remove_invalid_values`):
             Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash.
             Note that using `remove_invalid_values` can slow down generation.
         exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
@@ -233,7 +234,7 @@ class GenerationConfig(PushToHubMixin):
             Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
             sequence being selected, while negative biases do the opposite. Check
             [`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.
-        token_healing (`bool`, *optional*, defaults to `False`):
+        token_healing (`bool`, defaults to `False`):
             Heal tail tokens of prompts by replacing them with their appropriate extensions.
             This enhances the quality of completions for prompts affected by greedy tokenization bias.
         guidance_scale (`float`, *optional*):
@@ -249,18 +250,18 @@ class GenerationConfig(PushToHubMixin):

         num_return_sequences (`int`, *optional*, defaults to 1):
             The number of independently computed returned sequences for each element in the batch.
-        output_attentions (`bool`, *optional*, defaults to `False`):
+        output_attentions (`bool`, defaults to `False`):
             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
             tensors for more details.
-        output_hidden_states (`bool`, *optional*, defaults to `False`):
+        output_hidden_states (`bool`, defaults to `False`):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more details.
-        output_scores (`bool`, *optional*, defaults to `False`):
+        output_scores (`bool`, defaults to `False`):
             Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
-        output_logits (`bool`,
+        output_logits (`bool`, defaults to `False`):
             Whether or not to return the unprocessed prediction logit scores. See `logits` under returned tensors for
             more details.
-        return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+        return_dict_in_generate (`bool`, defaults to `False`):
             Whether or not to return a [`~utils.ModelOutput`], as opposed to returning exclusively the generated
             sequence. This flag must be set to `True` to return the generation cache (when `use_cache` is `True`)
             or optional outputs (see flags starting with `output_`)
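A short, hedged example of how the output flags above combine: the `output_*` switches only take effect when `return_dict_in_generate=True`; otherwise `generate` returns a bare tensor of token ids.

```python
from transformers import GenerationConfig

# Request the optional outputs documented above; without return_dict_in_generate=True
# these flags are ignored (and validate() reports them as minor issues).
gen_config = GenerationConfig(
    max_new_tokens=16,
    return_dict_in_generate=True,
    output_scores=True,
    output_logits=True,
)
# Typical call site (not executed here):
#   out = model.generate(**inputs, generation_config=gen_config)
#   out.sequences, out.scores, out.logits  # populated because of the flags above
```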
@@ -276,7 +277,7 @@ class GenerationConfig(PushToHubMixin):

         > Generation parameters exclusive to encoder-decoder models

-        encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
+        encoder_no_repeat_ngram_size (`int`, *optional*):
             If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
             `decoder_input_ids`.
         decoder_start_token_id (`int` or `list[int]`, *optional*):
@@ -285,20 +286,20 @@ class GenerationConfig(PushToHubMixin):
             (e.g. multilingual models with different target languages in one batch)

         > Generation parameters exclusive to assistant generation
-        is_assistant (`bool`, *optional*, defaults to `False`):
+        is_assistant (`bool`, defaults to `False`):
             Whether the model is an assistant (draft) model.
-        num_assistant_tokens (`int`, *optional*, defaults to 20):
+        num_assistant_tokens (`int`, *optional*):
             Defines the number of _speculative tokens_ that shall be generated by the assistant model before being
             checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation
             more _speculative_ : If the assistant model is performant larger speed-ups can be reached, if the assistant
             model requires lots of corrections, lower speed-ups are reached.
-        num_assistant_tokens_schedule (`str`, *optional*, defaults to `"constant"`):
+        num_assistant_tokens_schedule (`str`, *optional*):
             Defines the schedule at which max assistant tokens shall be changed during inference.
             - `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2 else
               reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
             - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
             - `"constant"`: `num_assistant_tokens` stays unchanged during generation
-        assistant_confidence_threshold (`float`, *optional*, defaults to 0.4):
+        assistant_confidence_threshold (`float`, *optional*):
             The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
             than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
             (defined by `num_assistant_tokens`) is not yet reached. The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes, biased towards avoiding false negatives.
@@ -312,11 +313,11 @@ class GenerationConfig(PushToHubMixin):
         assistant_early_exit(`int`, *optional*):
             If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
             models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).
-        assistant_lookbehind(`int`, *optional*, defaults to 10):
+        assistant_lookbehind(`int`, *optional*):
             If set to a positive integer, the re-encodeing process will additionally consider the last `assistant_lookbehind` assistant tokens
             to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
             See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
-        target_lookbehind(`int`, *optional*, defaults to 10):
+        target_lookbehind(`int`, *optional*):
             If set to a positive integer, the re-encodeing process will additionally consider the last `target_lookbehind` target tokens
             to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
             See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
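A hedged sketch of an assisted-generation setup using the parameters documented above; `model`, `draft_model` and the chosen values are placeholders, not recommendations.

```python
from transformers import GenerationConfig

# Speculative-decoding knobs from the docstring above; the documented defaults are
# num_assistant_tokens=20, schedule="constant", confidence threshold 0.4.
assisted_config = GenerationConfig(
    max_new_tokens=64,
    num_assistant_tokens=10,
    num_assistant_tokens_schedule="heuristic",
    assistant_confidence_threshold=0.3,
)
# Typical call site (not executed here):
#   model.generate(**inputs, generation_config=assisted_config, assistant_model=draft_model)
```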
@@ -326,7 +327,7 @@ class GenerationConfig(PushToHubMixin):
         compile_config (CompileConfig, *optional*):
             If using a compilable cache, this controls how `generate` will `compile` the forward pass for faster
             inference.
-        disable_compile (`bool`,
+        disable_compile (`bool`, defaults to `False`):
             Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when
             specific criteria are met, including using a compilable cache. Please open an issue if you find the
             need to use this flag.
@@ -336,38 +337,36 @@ class GenerationConfig(PushToHubMixin):

     def __init__(self, **kwargs):
         # Parameters that control the length of the output
-        self.max_length = kwargs.pop("max_length", 20)
+        self.max_length = kwargs.pop("max_length", None)
         self.max_new_tokens = kwargs.pop("max_new_tokens", None)
-        self.min_length = kwargs.pop("min_length", 0)
+        self.min_length = kwargs.pop("min_length", None)
         self.min_new_tokens = kwargs.pop("min_new_tokens", None)
-        self.early_stopping = kwargs.pop("early_stopping", False)
+        self.early_stopping = kwargs.pop("early_stopping", None)
         self.max_time = kwargs.pop("max_time", None)
         self.stop_strings = kwargs.pop("stop_strings", None)

         # Parameters that control the generation strategy used
         self.do_sample = kwargs.pop("do_sample", False)
-        self.num_beams = kwargs.pop("num_beams", 1)
+        self.num_beams = kwargs.pop("num_beams", None)

         # Parameters that control the cache
         self.use_cache = kwargs.pop("use_cache", True)
         self.cache_implementation = kwargs.pop("cache_implementation", None)
         self.cache_config = kwargs.pop("cache_config", None)

-        self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None)
-
         # Parameters for manipulation of the model output logits
-        self.temperature = kwargs.pop("temperature", 1.0)
-        self.top_k = kwargs.pop("top_k", 50)
-        self.top_p = kwargs.pop("top_p", 1.0)
+        self.temperature = kwargs.pop("temperature", None)
+        self.top_k = kwargs.pop("top_k", None)
+        self.top_p = kwargs.pop("top_p", None)
         self.min_p = kwargs.pop("min_p", None)
         self.top_h = kwargs.pop("top_h", None)
-        self.typical_p = kwargs.pop("typical_p", 1.0)
-        self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", 0.0)
-        self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
-        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
-        self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", 1.0)
-        self.length_penalty = kwargs.pop("length_penalty", 1.0)
-        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
+        self.typical_p = kwargs.pop("typical_p", None)
+        self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", None)
+        self.eta_cutoff = kwargs.pop("eta_cutoff", None)
+        self.repetition_penalty = kwargs.pop("repetition_penalty", None)
+        self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", None)
+        self.length_penalty = kwargs.pop("length_penalty", None)
+        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", None)
         self.bad_words_ids = kwargs.pop("bad_words_ids", None)
         self.renormalize_logits = kwargs.pop("renormalize_logits", False)
         self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
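The hunk above switches most `kwargs.pop` defaults to `None`, so "never set by the user" is distinguishable from "explicitly set to the legacy default". A standalone sketch of the same pattern (plain Python, not transformers code):

```python
def build_config(**kwargs):
    # Popping with a None default keeps "unset" observable; a later resolution
    # step (or the model's own generation_config.json) can then fill the gaps.
    return {
        "temperature": kwargs.pop("temperature", None),
        "top_k": kwargs.pop("top_k", None),
        "num_beams": kwargs.pop("num_beams", None),
    }

print(build_config(temperature=0.7))  # {'temperature': 0.7, 'top_k': None, 'num_beams': None}
print(build_config())                 # {'temperature': None, 'top_k': None, 'num_beams': None}
```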
@@ -380,20 +379,16 @@ class GenerationConfig(PushToHubMixin):
         self.token_healing = kwargs.pop("token_healing", False)
         self.guidance_scale = kwargs.pop("guidance_scale", None)

-        watermarking_config = kwargs.pop("watermarking_config", None)
-        if watermarking_config is None:
-            self.watermarking_config = None
-        elif isinstance(watermarking_config, BaseWatermarkingConfig):
-            self.watermarking_config = watermarking_config
-        else:
-            self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config)
+        self.watermarking_config = kwargs.pop("watermarking_config", None)
+        if isinstance(self.watermarking_config, dict):
+            self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)

         # Parameters that define the output variables of `generate`
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
         self.output_attentions = kwargs.pop("output_attentions", False)
         self.output_hidden_states = kwargs.pop("output_hidden_states", False)
         self.output_scores = kwargs.pop("output_scores", False)
-        self.output_logits = kwargs.pop("output_logits",
+        self.output_logits = kwargs.pop("output_logits", False)
         self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)

         # Special tokens that can be used at generation time
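With the simplified branch above, a plain dict passed as `watermarking_config` is normalized through `WatermarkingConfig.from_dict`, while an already-built watermarking config object is stored untouched. A minimal sketch; the field values are assumptions about a typical setup, not recommendations.

```python
from transformers import GenerationConfig

config = GenerationConfig(
    do_sample=True,
    watermarking_config={"bias": 2.0, "seeding_scheme": "selfhash", "context_width": 1},
)
# Under the rc2 __init__ above, the dict should come back as a WatermarkingConfig instance.
print(type(config.watermarking_config).__name__)
```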
@@ -402,57 +397,57 @@ class GenerationConfig(PushToHubMixin):
         self.eos_token_id = kwargs.pop("eos_token_id", None)

         # Generation parameters exclusive to encoder-decoder models
-        self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
+        self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", None)
         self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

         # Assistant generation
-        self.is_assistant = False
-        self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20)
-        self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant")
-        self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4)
+        self.is_assistant = kwargs.pop("is_assistant", False)
+        self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", None)
+        self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", None)
+        self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", None)
         self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
         self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
         self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)
-
-        self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
-        self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
+        self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", None)
+        self.target_lookbehind = kwargs.pop("target_lookbehind", None)

         # Performance
         self.compile_config = kwargs.pop("compile_config", None)
         self.disable_compile = kwargs.pop("disable_compile", False)

-        # Deprecated (moved to the Hub). TODO
+        # Deprecated (moved to the Hub). TODO remove for v5
         self.low_memory = kwargs.pop("low_memory", None)
         self.penalty_alpha = kwargs.pop("penalty_alpha", None)
         self.dola_layers = kwargs.pop("dola_layers", None)
-        self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
-        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
+        self.diversity_penalty = kwargs.pop("diversity_penalty", None)
+        self.num_beam_groups = kwargs.pop("num_beam_groups", None)
         self.constraints = kwargs.pop("constraints", None)
         self.force_words_ids = kwargs.pop("force_words_ids", None)

-
-        # interface.
-        self._from_model_config = kwargs.pop("_from_model_config", False)
-        self._commit_hash = kwargs.pop("_commit_hash", None)
-        self.transformers_version = kwargs.pop("transformers_version", __version__)
+        self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None)

-        #
-
-
-
-            f"Please make sure the generation config includes `forced_bos_token_id={self.bos_token_id}`. "
-        )
+        # Common attributes
+        self._commit_hash = kwargs.pop("_commit_hash", None)
+        self._from_model_config = kwargs.pop("_from_model_config", None)
+        self.transformers_version = kwargs.pop("transformers_version", None)

         # Additional attributes without default values
         if not self._from_model_config:
-            # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
-            # model's default configuration file
+            # we don't want to copy values from the model config if we're initializing
+            # a `GenerationConfig` from a model's default configuration file
             for key, value in kwargs.items():
                 try:
                     setattr(self, key, value)
                 except AttributeError as err:
                     logger.error(f"Can't set {key} with value {value} for {self}")
                     raise err
+        else:
+            # Ensure backward compatibility for models that use `forced_bos_token_id` within their config
+            if kwargs.get("force_bos_token_to_be_generated", False):
+                self.forced_bos_token_id = self.bos_token_id
+                logger.warning_once(
+                    f"Please make sure the generation config includes `forced_bos_token_id={self.bos_token_id}`. "
+                )

         # Validate the values of the attributes
         self.validate()
@@ -487,8 +482,8 @@ class GenerationConfig(PushToHubMixin):
         # property and part of the `__repr__`
         if self.constraints is not None or self.force_words_ids is not None:
             generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
-        elif self.num_beams == 1:
-            if self.do_sample
+        elif self.num_beams is None or self.num_beams == 1:
+            if not self.do_sample:
                 if (
                     self.top_k is not None
                     and self.top_k > 1
@@ -501,9 +496,9 @@ class GenerationConfig(PushToHubMixin):
             else:
                 generation_mode = GenerationMode.SAMPLE
         else:
-            if self.num_beam_groups > 1:
+            if self.num_beam_groups is not None and self.num_beam_groups > 1:
                 generation_mode = GenerationMode.GROUP_BEAM_SEARCH
-            elif self.do_sample
+            elif self.do_sample:
                 generation_mode = GenerationMode.BEAM_SAMPLE
             else:
                 generation_mode = GenerationMode.BEAM_SEARCH
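The change above makes mode selection tolerant of unset (`None`) values: an unset `num_beams` or `num_beam_groups` behaves like 1. A simplified, standalone sketch of that branch logic; it deliberately ignores constrained search, contrastive search and the other modes handled by the real method.

```python
def pick_mode(num_beams=None, do_sample=False, num_beam_groups=None):
    # None-aware version of the branching shown in the hunk above.
    if num_beams is None or num_beams == 1:
        return "SAMPLE" if do_sample else "GREEDY_SEARCH"
    if num_beam_groups is not None and num_beam_groups > 1:
        return "GROUP_BEAM_SEARCH"
    return "BEAM_SAMPLE" if do_sample else "BEAM_SEARCH"

print(pick_mode())                                # GREEDY_SEARCH
print(pick_mode(num_beams=4))                     # BEAM_SEARCH
print(pick_mode(num_beams=4, num_beam_groups=2))  # GROUP_BEAM_SEARCH
```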
@@ -536,6 +531,45 @@ class GenerationConfig(PushToHubMixin):
             )
         return generation_mode

+    @staticmethod
+    def _get_default_generation_params() -> dict[str, Any]:
+        return {
+            "max_length": 20,
+            "min_length": 0,
+            "do_sample": False,
+            "early_stopping": False,
+            "num_beams": 1,
+            "temperature": 1.0,
+            "top_k": 50,
+            "top_p": 1.0,
+            "typical_p": 1.0,
+            "repetition_penalty": 1.0,
+            "length_penalty": 1.0,
+            "no_repeat_ngram_size": 0,
+            "encoder_no_repeat_ngram_size": 0,
+            "bad_words_ids": None,
+            "num_return_sequences": 1,
+            "output_scores": False,
+            "return_dict_in_generate": False,
+            "forced_bos_token_id": None,
+            "forced_eos_token_id": None,
+            "remove_invalid_values": False,
+            "exponential_decay_length_penalty": None,
+            "suppress_tokens": None,
+            "begin_suppress_tokens": None,
+            "epsilon_cutoff": 0.0,
+            "eta_cutoff": 0.0,
+            "encoder_repetition_penalty": 1.0,
+            "num_assistant_tokens": 20,
+            "num_assistant_tokens_schedule": "constant",
+            "assistant_confidence_threshold": 0.4,
+            "assistant_lookbehind": 10,
+            "target_lookbehind": 10,
+            # Deprecated arguments (moved to the Hub). TODO joao, manuel: remove in v4.62.0
+            "num_beam_groups": 1,
+            "diversity_penalty": 0.0,
+        }
+
     def validate(self, strict=False):
         """
         Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
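The new `_get_default_generation_params()` table pairs with the `None`-by-default `__init__`: unset attributes can be resolved against it at use time. A hedged sketch of one way to perform that resolution; the actual call site is not part of this hunk.

```python
from transformers import GenerationConfig

config = GenerationConfig()  # most knobs are stored as None under rc2
defaults = GenerationConfig._get_default_generation_params()

resolved = {
    name: getattr(config, name, None) if getattr(config, name, None) is not None else default
    for name, default in defaults.items()
}
print(resolved["temperature"], resolved["top_k"], resolved["num_beams"])  # 1.0 50 1
```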
@@ -551,7 +585,7 @@ class GenerationConfig(PushToHubMixin):

         # 1. Validation of individual attributes
         # 1.1. Decoding attributes
-        if self.early_stopping not in {True, False, "never"}:
+        if self.early_stopping not in {None, True, False, "never"}:
             raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
         if self.max_new_tokens is not None and self.max_new_tokens <= 0:
             raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
@@ -582,9 +616,9 @@ class GenerationConfig(PushToHubMixin):

         # 2. Validation of attribute combinations
         # 2.1. detect sampling-only parameterization when not in sampling mode
-        if self.do_sample
+        if not self.do_sample:
             greedy_wrong_parameter_msg = (
-                "`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
+                "`do_sample` is set not to set `True`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
                 "used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
             )
             if self.temperature is not None and self.temperature != 1.0:
@@ -613,42 +647,42 @@ class GenerationConfig(PushToHubMixin):
                 )

         # 2.2. detect beam-only parameterization when not in beam mode
-        if self.num_beams == 1:
+        if self.num_beams is None or self.num_beams == 1:
             single_beam_wrong_parameter_msg = (
-                "`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
+                "`num_beams` is set to {num_beams}. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
                 "in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
             )
-            if self.early_stopping is not False:
+            if self.early_stopping is not None and self.early_stopping is not False:
                 minor_issues["early_stopping"] = single_beam_wrong_parameter_msg.format(
-                    flag_name="early_stopping", flag_value=self.early_stopping
+                    num_beams=self.num_beams, flag_name="early_stopping", flag_value=self.early_stopping
                 )
             if self.length_penalty is not None and self.length_penalty != 1.0:
                 minor_issues["length_penalty"] = single_beam_wrong_parameter_msg.format(
-                    flag_name="length_penalty", flag_value=self.length_penalty
+                    num_beams=self.num_beams, flag_name="length_penalty", flag_value=self.length_penalty
                 )

         # 2.4. check `num_return_sequences`
-        if self.num_return_sequences
-            if self.num_beams == 1:
-                if self.do_sample
+        if self.num_return_sequences > 1:
+            if self.num_beams is None or self.num_beams == 1:
+                if not self.do_sample:
                     raise ValueError(
-                        "Greedy methods without beam search do not support `num_return_sequences` different than 1 "
-                        f"(got {self.num_return_sequences})."
+                        "Greedy methods (do_sample != True) without beam search do not support "
+                        f"`num_return_sequences` different than 1 (got {self.num_return_sequences})."
                     )
-            elif self.num_return_sequences > self.num_beams:
+            elif self.num_beams is not None and self.num_return_sequences > self.num_beams:
                 raise ValueError(
                     f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
                     f"({self.num_beams})."
                 )

         # 2.5. check cache-related arguments
-        if self.use_cache
+        if not self.use_cache:
             # In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often used
             # passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
             # (otherwise a user might need to overwrite several parameters).
             no_cache_warning = (
-                "You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will "
-                "have no effect."
+                "You have not set `use_cache` to `True`, but {cache_arg} is set to {cache_arg_value}."
+                "{cache_arg} will have no effect."
             )
             for arg_name in ("cache_implementation", "cache_config"):
                 if getattr(self, arg_name) is not None:
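A small example of the checks above in action; the assumption that `strict=True` escalates the collected minor issues into errors follows from the `strict` flag in the signature of `validate` shown earlier, not from this hunk.

```python
from transformers import GenerationConfig

# temperature is a sampling-only flag, so setting it while do_sample stays unset
# is reported as a minor issue when validate() runs at the end of __init__.
config = GenerationConfig(temperature=0.7)

config.do_sample = True
config.validate(strict=True)  # with sampling enabled, the configuration is consistent
```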
@@ -657,9 +691,9 @@ class GenerationConfig(PushToHubMixin):
                     )

         # 2.6. other incorrect combinations
-        if self.return_dict_in_generate
+        if not self.return_dict_in_generate:
             for extra_output_flag in self.extra_output_flags:
-                if getattr(self, extra_output_flag)
+                if getattr(self, extra_output_flag):
                     minor_issues[extra_output_flag] = (
                         f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When "
                         f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored."
@@ -675,7 +709,6 @@ class GenerationConfig(PushToHubMixin):
             "streamer",
             "negative_prompt_ids",
             "negative_prompt_attention_mask",
-            "use_model_defaults",
         )
         for arg in generate_arguments:
             if hasattr(self, arg):
@@ -1100,7 +1133,7 @@ class GenerationConfig(PushToHubMixin):
             writer.write(self.to_json_string(use_diff=use_diff, keys_to_pop=keys_to_pop))

     @classmethod
-    def from_model_config(cls, model_config: PreTrainedConfig) -> "GenerationConfig":
+    def from_model_config(cls, model_config: Union["PreTrainedConfig", dict]) -> "GenerationConfig":
         """
         Instantiates a [`GenerationConfig`] from a [`PreTrainedConfig`]. This function is useful to convert legacy
         [`PreTrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].
@@ -1117,23 +1150,28 @@ class GenerationConfig(PushToHubMixin):

         # Removes all `None` from the model config dict -- this lets the generation config defaults to take hold
         config_dict = {key: value for key, value in config_dict.items() if value is not None}
-
         generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)

         # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
         # generation config (which in turn is defined from the outer attributes of model config).
-        if
-
-
-
-
-
-
-
-
+        if isinstance(model_config, dict):
+            decoder_possible_text_config_names = ("decoder", "generator", "text_config")
+            for text_config_name in decoder_possible_text_config_names:
+                if text_config := model_config.get(text_config_name):
+                    model_config = text_config
+                    break
+        else:
+            model_config = model_config.get_text_config(decoder=True)
+            model_config = model_config.to_dict()
+
+        default_generation_config = GenerationConfig()
+        for attr in generation_config.to_dict():
+            is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
+            if attr in model_config and is_unset:
+                setattr(generation_config, attr, model_config[attr])

         # If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
-        if generation_config.return_dict_in_generate
+        if not generation_config.return_dict_in_generate:
             if any(
                 getattr(generation_config, extra_output_flag, False)
                 for extra_output_flag in generation_config.extra_output_flags
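A hedged sketch of the new dict path: nested `decoder` / `generator` / `text_config` sections are searched for generation attributes that are still unset on the freshly built config. It assumes the parts of `from_model_config` not shown in this hunk accept a plain dict end to end.

```python
from transformers import GenerationConfig

legacy_model_config = {
    "vocab_size": 32000,
    "text_config": {"max_length": 128, "no_repeat_ngram_size": 3},
}
generation_config = GenerationConfig.from_model_config(legacy_model_config)
# Expected under the new handling: values from the nested text_config are picked up
# because the corresponding attributes were still unset.
print(generation_config.max_length, generation_config.no_repeat_ngram_size)
```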
@@ -1144,12 +1182,14 @@ class GenerationConfig(PushToHubMixin):
         generation_config._original_object_hash = hash(generation_config)
         return generation_config

-    def update(self, **kwargs):
+    def update(self, defaults_only=False, **kwargs):
         """
         Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
         returning all the unused kwargs.

         Args:
+            defaults_only (`bool`, *optional*, defaults to `False`):
+                Whether to update all keys in config with `kwargs` or only those that are set to `None` (i.e. default value).
             kwargs (`dict[str, Any]`):
                 Dictionary of attributes to tentatively update this class.

@@ -1159,8 +1199,9 @@ class GenerationConfig(PushToHubMixin):
         to_remove = []
         for key, value in kwargs.items():
             if hasattr(self, key):
-                setattr(self, key, value)
-                to_remove.append(key)
+                if not defaults_only or getattr(self, key) is None:
+                    setattr(self, key, value)
+                    to_remove.append(key)

         # Confirm that the updated instance is still valid
         self.validate()
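The new `defaults_only` switch makes `update` fill in only attributes that are still unset (`None`), leaving explicit choices alone; a short usage sketch:

```python
from transformers import GenerationConfig

config = GenerationConfig(do_sample=True, temperature=0.7)
unused = config.update(defaults_only=True, temperature=1.0, top_k=50)

print(config.temperature)  # 0.7 -- already set, so left untouched
print(config.top_k)        # 50  -- was None, so filled in
# `unused` collects the kwargs that were not applied to the config.
```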