transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/auto/video_processing_auto.py

@@ -53,6 +53,7 @@ if TYPE_CHECKING:
 else:
     VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(
         [
+            ("ernie4_5_vl_moe", "Ernie4_5_VL_MoeVideoProcessor"),
             ("glm46v", "Glm46VVideoProcessor"),
             ("glm4v", "Glm4vVideoProcessor"),
             ("instructblip", "InstructBlipVideoVideoProcessor"),
@@ -60,6 +61,8 @@ else:
             ("internvl", "InternVLVideoProcessor"),
             ("llava_next_video", "LlavaNextVideoVideoProcessor"),
             ("llava_onevision", "LlavaOnevisionVideoProcessor"),
+            ("pe_audio_video", "PeVideoVideoProcessor"),
+            ("pe_video", "PeVideoVideoProcessor"),
             ("perception_lm", "PerceptionLMVideoProcessor"),
             ("qwen2_5_omni", "Qwen2VLVideoProcessor"),
             ("qwen2_5_vl", "Qwen2VLVideoProcessor"),
@@ -373,9 +376,9 @@ class AutoVideoProcessor:
             video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
             _ = kwargs.pop("code_revision", None)
             video_processor_class.register_for_auto_class()
-            return video_processor_class.
+            return video_processor_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
         elif video_processor_class is not None:
-            return video_processor_class.
+            return video_processor_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
         # Last try: we use the VIDEO_PROCESSOR_MAPPING.
         elif type(config) in VIDEO_PROCESSOR_MAPPING:
             video_processor_class = VIDEO_PROCESSOR_MAPPING[type(config)]
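The last hunk completes both early-return branches of `AutoVideoProcessor.from_pretrained` (the removed lines are cut off in this report): whether the class comes from trusted remote code or is already registered, loading now delegates to the resolved class's own `from_pretrained`. A minimal usage sketch; the repo id is a placeholder, and the commented call shape depends on the concrete processor class:

```python
from transformers import AutoVideoProcessor

# Resolves the concrete class (e.g. Qwen2VLVideoProcessor) from the
# checkpoint's config, then calls that class's own from_pretrained(),
# per the rc2 change above.
video_processor = AutoVideoProcessor.from_pretrained("some-org/some-video-model")

# Typical downstream call (accepted input types vary per processor class):
# inputs = video_processor(videos=frames, return_tensors="pt")
```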
transformers/models/autoformer/modeling_autoformer.py

@@ -903,6 +903,7 @@ class AutoformerEncoder(AutoformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         Args:
@@ -1024,6 +1025,7 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, AutoFormerDecoderOutput]:
         r"""
         Args:
@@ -1360,6 +1362,7 @@ class AutoformerModel(AutoformerPreTrainedModel):
         use_cache: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[AutoformerModelOutput, tuple]:
         r"""
         past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -1610,6 +1613,7 @@ class AutoformerForPrediction(AutoformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Seq2SeqTSPredictionOutput, tuple]:
         r"""
         past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
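All four Autoformer hunks make the same change: each `forward` now accepts and absorbs `**kwargs`. The diff itself does not state the motivation; a plausible reading is that shared 5.0 call sites forward common keyword arguments through every submodule, so older signatures must tolerate keys they do not use. An illustrative stand-in for that pattern (not the Autoformer source):

```python
from typing import Optional


class Encoder:
    # Accepting **kwargs lets callers pass newer, shared options
    # (e.g. attention-backend hints) without raising TypeError here.
    def forward(self, values: list, output_attentions: Optional[bool] = None, **kwargs):
        # Keys this layer does not understand are simply ignored.
        return values


# Runs even though this encoder knows nothing about `output_router_logits`.
Encoder().forward([1.0, 2.0], output_attentions=False, output_router_logits=True)
```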
transformers/models/aya_vision/modeling_aya_vision.py

@@ -471,6 +471,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -482,12 +483,15 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        if
-        #
-        #
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration if available.
+            # In subsequent iterations, they are already merged with text and cached.
+            # NOTE: the first iteration doesn't have to be prefill; it can be the first
+            # iteration with a question and a cached system prompt (continue generating from cache)
             model_inputs["pixel_values"] = pixel_values

         return model_inputs
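The rewritten gate keys off an explicit `is_first_iteration` flag (the removed condition is cut off in this report). The added comments carry the reasoning: with a KV cache, image features enter the sequence once and then persist in the cached keys and values, so later decoding steps must not resend them. A self-contained sketch of the idea, using a hypothetical helper rather than the real `prepare_inputs_for_generation` signature:

```python
def prepare_step_inputs(token_ids, pixel_values, is_first_iteration, use_cache=True):
    """Toy version of the gate above: forward pixel values only when their
    effect is not already folded into the KV cache."""
    inputs = {"input_ids": token_ids}
    if is_first_iteration or not use_cache:
        inputs["pixel_values"] = pixel_values
    return inputs


first = prepare_step_inputs([1, 2, 3], "<image tensor>", is_first_iteration=True)
later = prepare_step_inputs([4], "<image tensor>", is_first_iteration=False)
assert "pixel_values" in first and "pixel_values" not in later
```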
@@ -35,7 +35,7 @@ from transformers.activations import ACT2FN
|
|
|
35
35
|
from ... import initialization as init
|
|
36
36
|
from ...cache_utils import Cache
|
|
37
37
|
from ...generation import GenerationMixin
|
|
38
|
-
from ...integrations import use_kernel_forward_from_hub
|
|
38
|
+
from ...integrations import lazy_load_kernel, use_kernel_forward_from_hub, use_kernelized_func
|
|
39
39
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
|
40
40
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
41
41
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
|
@@ -43,22 +43,10 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
|
43
43
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
44
44
|
from ...processing_utils import Unpack
|
|
45
45
|
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
|
46
|
-
from ...utils.
|
|
46
|
+
from ...utils.generic import maybe_autocast
|
|
47
47
|
from .configuration_bamba import BambaConfig
|
|
48
48
|
|
|
49
49
|
|
|
50
|
-
if is_mamba_2_ssm_available():
|
|
51
|
-
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
|
52
|
-
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
|
|
53
|
-
else:
|
|
54
|
-
selective_state_update = None
|
|
55
|
-
|
|
56
|
-
if is_causal_conv1d_available():
|
|
57
|
-
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
|
58
|
-
else:
|
|
59
|
-
causal_conv1d_update, causal_conv1d_fn = None, None
|
|
60
|
-
|
|
61
|
-
|
|
62
50
|
logger = logging.get_logger(__name__)
|
|
63
51
|
|
|
64
52
|
|
|
@@ -211,7 +199,7 @@ class BambaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
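Registering `original_inv_freq` as a non-persistent buffer with a `.clone()` fixes an aliasing hazard: the previous `self.original_inv_freq = self.inv_freq` pointed at the very tensor that dynamic-RoPE logic can rescale in place, and a bare attribute also misses `.to()` device and dtype moves. A small demonstration of the difference:

```python
# Buffer + clone keeps a true snapshot of the init-time frequencies; the old
# attribute assignment aliased the live tensor, so later in-place edits of
# `inv_freq` would have leaked into the "original" copy.
import torch
from torch import nn

class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

rope = Rope()
rope.inv_freq.mul_(0.5)  # e.g. a dynamic frequency rescale, done in place
print(torch.equal(rope.original_inv_freq, rope.inv_freq))  # False: snapshot intact
```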
@@ -250,7 +238,7 @@ class BambaRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
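`maybe_autocast` from `transformers.utils.generic` replaces the bare `torch.autocast` call here. Presumably it degrades to a no-op on device types where autocast is unavailable; a hedged sketch under that assumption (not the actual implementation):

```python
# Hypothetical sketch of a `maybe_autocast`-style helper: fall back to a no-op
# context manager when the device type does not support autocast.
from contextlib import contextmanager, nullcontext
import torch

@contextmanager
def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
    try:
        ctx = torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:  # unknown or unsupported device type
        ctx = nullcontext()
    with ctx:
        yield

with maybe_autocast(device_type="cpu", enabled=False):  # force float32 math
    freqs = torch.randn(2, 4) @ torch.randn(4, 3)
```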
@@ -345,6 +333,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class BambaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -370,7 +359,6 @@ class BambaAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
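The deleted per-instance `self.rotary_fn = apply_rotary_pos_emb` is superseded by the class-level `@use_kernelized_func(apply_rotary_pos_emb)` decorator above, which gives the kernel machinery a single place to substitute an optimized rotary implementation. A toy sketch of the shape of such a decorator (illustrative only, not the transformers implementation):

```python
# Toy decorator in the spirit of `use_kernelized_func`: attach a default
# implementation at the class level so a faster kernel can be swapped in
# centrally instead of via per-instance attributes.
def use_kernelized_func(default_fn):
    def wrap(cls):
        cls.rotary_fn = staticmethod(default_fn)  # replaceable in one place
        return cls
    return wrap

def apply_rotary_pos_emb(q, k):  # stand-in signature
    return q, k

@use_kernelized_func(apply_rotary_pos_emb)
class Attention:
    def forward(self, q, k):
        return self.rotary_fn(q, k)

print(Attention().forward(1, 2))  # (1, 2), via the class-level default
```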
@@ -500,9 +488,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
     return hidden_states


-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
 class BambaMixer(nn.Module):
     """
@@ -574,6 +559,20 @@ class BambaMixer(nn.Module):

         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
+        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
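Rather than importing `mamba_ssm` and `causal_conv1d` at module import time, the mixer now resolves them through `lazy_load_kernel` when the first layer is constructed, so `is_fast_path_available` reflects what is actually loadable at that moment. A hypothetical sketch of the loader's contract (the real helper lives in `transformers.integrations` and may also resolve kernels from the Hugging Face Hub):

```python
# Hypothetical sketch: import the optional kernel package on first use and
# cache the module (or None), so repeated lookups stay cheap.
import importlib
from functools import lru_cache

@lru_cache(maxsize=None)
def lazy_load_kernel(name: str):
    module_name = name.replace("-", "_")  # "causal-conv1d" -> "causal_conv1d"
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None  # caller falls back to the slow path

causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
print("fast path:", causal_conv1d_fn is not None)
```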
@@ -1488,6 +1487,7 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
@@ -1520,7 +1520,7 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and empty_past_kv:
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
--- a/transformers/models/bamba/modular_bamba.py
+++ b/transformers/models/bamba/modular_bamba.py
@@ -43,6 +43,7 @@ from transformers.models.mamba2.modeling_mamba2 import (
 )

 from ... import initialization as init
+from ...integrations import lazy_load_kernel
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import PreTrainedModel
@@ -52,24 +53,9 @@ from ...utils import (
     can_return_tuple,
     logging,
 )
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_bamba import BambaConfig


-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    selective_state_update = None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 logger = logging.get_logger(__name__)
@@ -276,6 +262,20 @@ class BambaMixer(nn.Module):

         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
+        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
@@ -1151,6 +1151,7 @@ class BambaForCausalLM(LlamaForCausalLM):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
@@ -1183,7 +1184,7 @@ class BambaForCausalLM(LlamaForCausalLM):
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and empty_past_kv:
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
--- a/transformers/models/bark/modeling_bark.py
+++ b/transformers/models/bark/modeling_bark.py
@@ -23,6 +23,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F

+from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...generation.logits_process import (
@@ -349,6 +350,14 @@ class BarkPreTrainedModel(PreTrainedModel):

         return super().device

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, BarkSelfAttention):
+            if module.is_causal:
+                block_size = module.config.block_size
+                bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size)
+                init.copy_(module.bias, bias)
+

 # GPT2-like autoregressive model
 class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
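Moving the causal `bias` construction into `_init_weights` routes it through the standard (re)initialization path instead of the module constructor. The mask being copied in, shown in isolation:

```python
# The causal bias initialized above: a lower-triangular boolean mask shaped
# to broadcast over (batch, heads, query_len, key_len).
import torch

block_size = 4
bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size)
print(bias[0, 0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])
```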
@@ -426,6 +435,7 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
         r"""
         input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*):
@@ -1028,6 +1038,7 @@ class BarkFineModel(BarkPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
         r"""
         codebook_idx (`int`):
--- a/transformers/models/bart/modeling_bart.py
+++ b/transformers/models/bart/modeling_bart.py
@@ -23,6 +23,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -476,6 +477,11 @@ class BartPreTrainedModel(PreTrainedModel):

     _can_compile_fullgraph = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, BartForConditionalGeneration):
+            init.zeros_(module.final_logits_bias)
+
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
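`final_logits_bias` is now zeroed through the shared `transformers.initialization` helpers rather than ad-hoc constructor code. A hedged sketch of what helpers like `init.zeros_` and `init.copy_` plausibly do; the real ones live in the new `transformers/initialization.py`, and this is assumed behavior, not its source:

```python
# Assumed behavior: in-place initializers that skip meta tensors, so weights
# created under a meta device can be materialized and filled later.
import torch

def zeros_(tensor: torch.Tensor) -> None:
    if tensor.device.type != "meta":
        with torch.no_grad():
            tensor.zero_()

def copy_(tensor: torch.Tensor, value: torch.Tensor) -> None:
    if tensor.device.type != "meta":
        with torch.no_grad():
            tensor.copy_(value)

final_logits_bias = torch.randn(1, 8)
zeros_(final_logits_bias)
print(final_logits_bias.abs().sum())  # tensor(0.)
```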
@@ -547,6 +553,7 @@ class BartEncoder(BartPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         Args:
@@ -694,6 +701,7 @@ class BartDecoder(BartPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         Args:
@@ -921,6 +929,7 @@ class BartModel(BartPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqModelOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1067,6 +1076,7 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqLMOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1228,6 +1238,7 @@ class BartForSequenceClassification(BartPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1360,6 +1371,7 @@ class BartForQuestionAnswering(BartPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1457,6 +1469,7 @@ class BartDecoderWrapper(BartPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.decoder = BartDecoder(config)
+        self.post_init()

     def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)
@@ -1505,6 +1518,7 @@ class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
--- a/transformers/models/barthez/tokenization_barthez.py
+++ b/transformers/models/barthez/tokenization_barthez.py
@@ -14,6 +14,8 @@
 # limitations under the License
 """Tokenization classes for the BARThez model."""

+from typing import Optional, Union
+
 from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers
 from tokenizers.models import Unigram

@@ -77,7 +79,7 @@ class BarthezTokenizer(TokenizersBackend):
         vocab_file (`str`, *optional*):
             [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
             contains the vocabulary necessary to instantiate a tokenizer.
-        vocab (`dict`, *optional*):
+        vocab (`str`, `dict` or `list`, *optional*):
             Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
         add_prefix_space (`bool`, *optional*, defaults to `True`):
             Whether or not to add an initial space to the input. This allows to treat the leading word just as any
@@ -90,6 +92,7 @@ class BarthezTokenizer(TokenizersBackend):

     def __init__(
         self,
+        vocab: Optional[Union[str, dict, list]] = None,
         bos_token="<s>",
         eos_token="</s>",
         sep_token="</s>",
@@ -97,15 +100,12 @@ class BarthezTokenizer(TokenizersBackend):
         unk_token="<unk>",
         pad_token="<pad>",
         mask_token="<mask>",
-        vocab_file=None,
-        vocab=None,
         add_prefix_space=True,
         **kwargs,
     ):
         # Mask token behave like a normal word, i.e. include the space before it
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
         self.add_prefix_space = add_prefix_space
-        self.vocab_file = vocab_file

         if vocab is not None:
             self._vocab = vocab
@@ -122,10 +122,7 @@ class BarthezTokenizer(TokenizersBackend):

         self._tokenizer.normalizer = normalizers.Sequence(
             [
-                normalizers.Replace("\n", " "),
-                normalizers.Replace("\r", " "),
-                normalizers.Replace("\t", " "),
-                normalizers.Replace(Regex(r" {2,}"), " "),
+                normalizers.Replace(Regex(r"\s{2,}|[\n\r\t]"), " "),
                 normalizers.NFC(),
                 normalizers.Strip(left=False, right=True),
             ]
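The four chained `Replace` normalizers collapse into a single regex: any run of two or more whitespace characters, or a single `\n`, `\r`, or `\t`, becomes one space. A quick equivalence check on mixed whitespace (requires the `tokenizers` package):

```python
# Old four-step sequence vs. the new single-regex rule on mixed whitespace.
from tokenizers import Regex, normalizers

old = normalizers.Sequence([
    normalizers.Replace("\n", " "),
    normalizers.Replace("\r", " "),
    normalizers.Replace("\t", " "),
    normalizers.Replace(Regex(r" {2,}"), " "),
])
new = normalizers.Replace(Regex(r"\s{2,}|[\n\r\t]"), " ")

s = "a\tb\r\nc   d"
print(old.normalize_str(s))  # "a b c d"
print(new.normalize_str(s))  # "a b c d"
```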
@@ -134,9 +131,7 @@ class BarthezTokenizer(TokenizersBackend):
         self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme=prepend_scheme)
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme=prepend_scheme)

-        tokenizer_object = self._tokenizer
         super().__init__(
-            tokenizer_object=tokenizer_object,
             bos_token=bos_token,
             eos_token=eos_token,
             unk_token=unk_token,
--- a/transformers/models/beit/image_processing_beit_fast.py
+++ b/transformers/models/beit/image_processing_beit_fast.py
@@ -163,7 +163,6 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

--- a/transformers/models/beit/modeling_beit.py
+++ b/transformers/models/beit/modeling_beit.py
@@ -216,7 +216,7 @@ class BeitPatchEmbeddings(nn.Module):
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )

-        embeddings = self.projection(pixel_values)
+        embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
         patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
         embeddings = embeddings.flatten(2).transpose(1, 2)

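Casting pixel values to the projection weight's dtype keeps the patch embedding working when the model runs in half precision but the image processor returns `float32` tensors. The mismatch in isolation (fp16 conv on CPU needs a recent torch; on GPU this is the usual half-precision setup):

```python
# fp16 weights vs. fp32 pixels: the uncast call raises a dtype error on most
# backends; casting the input to the weight dtype makes it consistent.
import torch
from torch import nn

proj = nn.Conv2d(3, 8, kernel_size=4, stride=4).to(torch.float16)
pixels = torch.randn(1, 3, 8, 8)           # float32, as a processor would return
out = proj(pixels.to(proj.weight.dtype))   # cast instead of proj(pixels)
print(out.dtype)                           # torch.float16
```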
@@ -726,6 +726,7 @@ class BeitModel(BeitPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BeitModelOutputWithPooling]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
@@ -818,6 +819,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
@@ -911,6 +913,7 @@ class BeitForImageClassification(BeitPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1244,6 +1247,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SemanticSegmenterOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
@@ -1371,6 +1375,7 @@ class BeitBackbone(BeitPreTrainedModel, BackboneMixin):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BackboneOutput:
         r"""
         Examples:
--- a/transformers/models/bert/modeling_bert.py
+++ b/transformers/models/bert/modeling_bert.py
@@ -569,6 +569,9 @@ class BertPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, BertLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, BertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 @dataclass
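The `position_ids` and `token_type_ids` buffers are now (re)filled during `_init_weights` rather than assumed from construction time. The values being copied into `position_ids`, in isolation:

```python
# A (1, max_position_embeddings) row of consecutive position indices.
import torch

max_positions = 8
position_ids = torch.arange(max_positions).expand((1, -1))
print(position_ids)        # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
print(position_ids.shape)  # torch.Size([1, 8])
```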
--- a/transformers/models/bert/tokenization_bert.py
+++ b/transformers/models/bert/tokenization_bert.py
@@ -15,7 +15,7 @@
 """Tokenization classes for Bert."""

 import collections
-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import WordPiece
@@ -48,8 +48,8 @@ class BertTokenizer(TokenizersBackend):
     this superclass for more information regarding those methods.

     Args:
-        vocab_file (`str`, *optional*):
-            File containing the vocabulary.
+        vocab (`str` or `dict[str, int]`, *optional*):
+            Custom vocabulary dictionary. If not provided, vocabulary is loaded from `vocab_file`.
         do_lower_case (`bool`, *optional*, defaults to `False`):
             Whether or not to lowercase the input when tokenizing.
         unk_token (`str`, *optional*, defaults to `"[UNK]"`):
@@ -72,17 +72,15 @@ class BertTokenizer(TokenizersBackend):
         strip_accents (`bool`, *optional*):
             Whether or not to strip all accents. If this option is not specified, then it will be determined by the
             value for `lowercase` (as in the original BERT).
-        vocab (`dict`, *optional*):
-            Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
     """

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
-
+    model = WordPiece

     def __init__(
         self,
-        vocab_file=None,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
         do_lower_case: bool = False,
         unk_token: str = "[UNK]",
         sep_token: str = "[SEP]",
@@ -91,28 +89,21 @@ class BertTokenizer(TokenizersBackend):
         mask_token: str = "[MASK]",
         tokenize_chinese_chars: bool = True,
         strip_accents: Optional[bool] = None,
-        vocab: Optional[dict] = None,
         **kwargs,
     ):
         self.do_lower_case = do_lower_case
         self.tokenize_chinese_chars = tokenize_chinese_chars
         self.strip_accents = strip_accents
-
-        if vocab is not None:
-            self._vocab = (
-                {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
-            )
-        else:
-            self._vocab = {
+        if vocab is None:
+            vocab = {
                 str(pad_token): 0,
                 str(unk_token): 1,
                 str(cls_token): 2,
                 str(sep_token): 3,
                 str(mask_token): 4,
             }
-
+        self._vocab = vocab
         self._tokenizer = Tokenizer(WordPiece(self._vocab, unk_token=str(unk_token)))
-
         self._tokenizer.normalizer = normalizers.BertNormalizer(
             clean_text=True,
             handle_chinese_chars=tokenize_chinese_chars,
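With the reworked branch, a `BertTokenizer` built without any `vocab` boots from a five-token special-tokens map, so construction no longer depends on a vocabulary file; every real word then maps to `[UNK]`. The same default exercised directly through the `tokenizers` backend:

```python
# A WordPiece model over only the special tokens is still a valid tokenizer;
# anything outside the map falls back to [UNK].
from tokenizers import Tokenizer
from tokenizers.models import WordPiece

vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4}
tok = Tokenizer(WordPiece(vocab, unk_token="[UNK]"))
print(tok.encode("hello").tokens)  # ['[UNK]']
```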
@@ -121,11 +112,7 @@ class BertTokenizer(TokenizersBackend):
         )
         self._tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
         self._tokenizer.decoder = decoders.WordPiece(prefix="##")
-
-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             do_lower_case=do_lower_case,
             unk_token=unk_token,
             sep_token=sep_token,
--- a/transformers/models/bert_generation/modeling_bert_generation.py
+++ b/transformers/models/bert_generation/modeling_bert_generation.py
@@ -463,6 +463,8 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, BertGenerationOnlyLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, BertGenerationEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 @auto_docstring(