transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/rag/modeling_rag.py

@@ -422,6 +422,8 @@ class RagModel(RagPreTrainedModel):
         self.ctx_encoder = None
         self.context_encoder_training = False
 
+        self.post_init()
+
     @auto_docstring
     def forward(
         self,
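Note: the `self.post_init()` additions in this and the two constructor hunks below follow the transformers convention that every `PreTrainedModel` subclass ends `__init__` with `post_init()`, which runs `_init_weights` over all submodules and applies weight tying. A minimal sketch of the convention (a toy model, not the actual RagModel code):

from torch import nn
from transformers import PretrainedConfig, PreTrainedModel

class ToyModel(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(4, 4)
        # Without this call, _init_weights and weight tying never run.
        self.post_init()

model = ToyModel(PretrainedConfig())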
@@ -439,6 +441,7 @@ class RagModel(RagPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_retrieved: Optional[bool] = None,
         n_docs: Optional[int] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], RetrievAugLMOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -689,6 +692,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         # instantiate model
         self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)
 
+        self.post_init()
+
     def set_retriever(self, retriever: RagRetriever):
         self.rag.retriever = retriever
 
@@ -1125,6 +1130,8 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
         # instantiate model
         self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)
 
+        self.post_init()
+
     def set_retriever(self, retriever: RagRetriever):
         self.rag.retriever = retriever
 
@@ -1403,7 +1410,6 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
         logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
         stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
-        use_model_defaults: Optional[bool] = None,
         **kwargs,
     ) -> torch.LongTensor:
         """
@@ -1462,11 +1468,6 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
                 Custom stopping criteria that complement the default stopping criteria built from arguments and a
                 model's config. If a stopping criteria is passed that is already created with the arguments or a
                 model's config an error is thrown.
-            use_model_defaults (`bool`, *optional*):
-                When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
-                generation configuration (`model.generation_config`), as opposed to the global defaults
-                (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
-                `True`.
             kwargs (`dict[str, Any]`, *optional*):
                 Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                 forwarded to the `forward` function of the model.
@@ -1478,9 +1479,7 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
         """
         # Handle `generation_config` and kwargs that might update it
         generation_mode_kwargs = self._extract_generation_mode_kwargs(None, kwargs, False, None, None)
-        generation_config, model_kwargs = self._prepare_generation_config(
-            generation_config, use_model_defaults, **kwargs
-        )
+        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
         generation_mode = generation_config.get_generation_mode()
         if generation_mode not in [
             GenerationMode.SAMPLE,
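Note: the three hunks above remove the `use_model_defaults` argument from `RagTokenForGeneration.generate()`; `_prepare_generation_config` now takes only the generation config and kwargs. A hedged migration sketch (the surrounding variables are illustrative):

# Before (rc0): outputs = model.generate(input_ids, use_model_defaults=True, max_new_tokens=32)
# After (rc2): drop the flag; model-specific generation defaults are applied internally.
outputs = model.generate(input_ids, max_new_tokens=32)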
transformers/models/recurrent_gemma/modeling_recurrent_gemma.py

@@ -31,6 +31,7 @@ from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
 from ...modeling_rope_utils import dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, logging
+from ...utils.generic import maybe_autocast
 from ...utils.import_utils import is_tracing
 from .configuration_recurrent_gemma import RecurrentGemmaConfig
 
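Note: `maybe_autocast` comes from `transformers.utils.generic`, whose implementation is not shown in this diff. As an assumption, it behaves like `torch.autocast` but degrades to a no-op context on device types where autocast is unsupported; a hypothetical sketch:

from contextlib import nullcontext

import torch

def maybe_autocast(device_type: str = "cpu", **kwargs):
    # Hypothetical helper: prefer torch.autocast, fall back to a
    # do-nothing context manager if the device type rejects it.
    try:
        return torch.autocast(device_type=device_type, **kwargs)
    except RuntimeError:
        return nullcontext()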
@@ -79,7 +80,7 @@ class RecurrentGemmaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     # Ignore copy
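Note: turning `original_inv_freq` from a plain attribute into a non-persistent buffer means the tensor now follows `module.to(...)` device and dtype moves while staying out of the `state_dict`. A standalone toy demonstration (not the real rotary-embedding class):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.arange(4, dtype=torch.float32)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

toy = Toy()
print(list(toy.state_dict()))  # [] -- persistent=False keeps buffers out of checkpoints
# toy.to("cuda") would move both buffers; a plain tensor attribute would stay behind.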
@@ -121,7 +122,7 @@ class RecurrentGemmaRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -460,6 +461,7 @@ class RecurrentGemmaRecurrentBlock(nn.Module):
         use_cache: bool = True,
     ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
         _, seq_len, _ = input_states.shape
+        batch_size = input_states.shape[0]
 
         y_branch = self.linear_y(input_states)
         y_branch = self.act_fn(y_branch)
@@ -468,6 +470,17 @@ class RecurrentGemmaRecurrentBlock(nn.Module):
         x_branch = x_branch.transpose(1, 2)
 
         if use_cache:
+            # Check if cache needs initialization (None or batch size mismatch)
+            if self.conv1d_state is None or self.conv1d_state.shape[0] != batch_size:
+                self.conv1d_state = torch.zeros(
+                    (batch_size, self.hidden_size, self.conv1d_width - 1),
+                    device=input_states.device,
+                    dtype=input_states.dtype,
+                )
+                self.rg_lru.recurrent_states = torch.zeros(
+                    (batch_size, self.lru_width), device=input_states.device, dtype=torch.float32
+                )
+
             if cache_position.shape[0] != 1:  # prefill
                 self.conv1d_state = nn.functional.pad(x_branch, (self.conv1d_width - x_branch.shape[-1] - 1, 0))
                 x_branch = self.conv_1d(x_branch)[..., :seq_len]
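Note: the block above now (re)allocates its conv and RG-LRU states lazily whenever they are missing or were sized for a different batch, rather than relying on `_setup_cache` having been called with the right batch size. A standalone sketch of the pattern with toy shapes (names are illustrative):

import torch

class LazyConvState:
    def __init__(self, hidden_size: int = 8, conv_width: int = 4):
        self.hidden_size, self.conv_width = hidden_size, conv_width
        self.state = None

    def ensure(self, x: torch.Tensor) -> torch.Tensor:
        # Re-allocate when the state is unset or the batch size changed.
        if self.state is None or self.state.shape[0] != x.shape[0]:
            self.state = torch.zeros(
                (x.shape[0], self.hidden_size, self.conv_width - 1),
                device=x.device,
                dtype=x.dtype,
            )
        return self.state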
@@ -598,10 +611,11 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
         # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
         if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
             init.zeros_(module.weight[module.padding_idx])
-
         # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
         elif isinstance(module, RecurrentGemmaRMSNorm):
             init.zeros_(module.weight)
+        elif isinstance(module, RecurrentGemmaModel):
+            init.constant_(module.normalizer, module.config.hidden_size**0.5)

     def _setup_cache(self, config, batch, device, dtype):
         layers = getattr(self, "model", self).layers
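Context for the two initializations above: the RecurrentGemma RMSNorm applies its learned scale as `(1 + weight)`, so zero-initializing `weight` yields an identity scale, while the model-level `normalizer` holds the usual `hidden_size ** 0.5` embedding multiplier. A standalone sketch of the zero-centered norm (illustrative, not the model's exact module):

```python
import torch
from torch import nn


class ZeroCenteredRMSNorm(nn.Module):
    """RMSNorm whose scale is stored as `weight` but applied as (1 + weight),
    so zero init means "multiply by 1"."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # zeros == identity scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * (1.0 + self.weight)
```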
@@ -643,6 +657,7 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1851,6 +1851,14 @@ class ReformerPreTrainedModel(PreTrainedModel):
         if isinstance(module, AxialPositionEmbeddings):
             for weight in module.weights:
                 init.normal_(weight, std=self.config.axial_norm_std)
+        elif isinstance(module, LSHSelfAttention):
+            init.constant_(module.self_mask_value_float16, -1e3)
+            init.constant_(module.self_mask_value_float32, -1e5)
+            init.constant_(module.mask_value_float16, -1e4)
+            init.constant_(module.mask_value_float32, -1e9)
+        elif isinstance(module, LocalSelfAttention):
+            init.constant_(module.mask_value_float16, -1e4)
+            init.constant_(module.mask_value_float32, -1e9)


 @dataclass
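The mask constants moved into `_init_weights` above are dtype-dependent on purpose: float16 tops out around ±65504, so a float32-style `-1e9` would overflow to `-inf` and can turn into NaNs after softmax. Quick verification:

```python
import torch

print(torch.tensor(-1e9, dtype=torch.float16))  # -inf: overflows float16
print(torch.tensor(-1e4, dtype=torch.float16))  # -10000.0: representable
print(torch.finfo(torch.float16).min)           # -65504.0
```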
@@ -1946,6 +1954,7 @@ class ReformerModel(ReformerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ReformerModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -2238,7 +2247,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel, GenerationMixin):
         )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, use_cache=None, num_hashes=None, **kwargs
+        self, input_ids, past_key_values=None, use_cache=None, num_hashes=None, is_first_iteration=False, **kwargs
     ):
         # Overwritten -- different expected inputs/outputs

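`prepare_inputs_for_generation` overrides in this release (Reformer here, RemBert further down) gain an `is_first_iteration=False` parameter, presumably letting the generation loop distinguish the initial prefill call from later decoding steps. A hedged sketch of how an override might consume it (the flag name comes from the diff; the branch body is illustrative):

```python
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, is_first_iteration=False, **kwargs):
    # Sketch only: on the first iteration the full prompt is forwarded;
    # afterwards only the newly generated token needs to be fed.
    if not is_first_iteration and past_key_values is not None:
        input_ids = input_ids[:, -1:]
    return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
```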
@@ -2297,6 +2306,7 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -2428,6 +2438,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -2577,6 +2588,7 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
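These three hunks are instances of a change repeated throughout the release: `forward` signatures grow a trailing `**kwargs` so shared generation and loss plumbing can pass extra keyword arguments without each model enumerating them. The trade-off is that misspelled keywords are now absorbed silently:

```python
def forward_strict(input_ids, return_dict=None):
    return input_ids


def forward_lenient(input_ids, return_dict=None, **kwargs):
    return input_ids


forward_lenient([1, 2], return_dcit=True)   # typo silently absorbed by **kwargs
# forward_strict([1, 2], return_dcit=True)  # TypeError: unexpected keyword argument
```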
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Tokenization class for model Reformer."""

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers
 from tokenizers.models import BPE
@@ -60,38 +60,27 @@ class ReformerTokenizer(TokenizersBackend):
             The token used for padding, for example when batching sequences of different lengths.
         additional_special_tokens (`list[str]`, *optional*):
             Additional special tokens used by the tokenizer.
-        vocab (`dict`, *optional*):
-            Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file
-        merges (`list`, *optional*):
-            Custom merges list. If not provided, merges are loaded from vocab_file
+        vocab (`str` or `dict[str, int]`, *optional*):
+            Custom vocabulary dictionary. If not provided, vocabulary is loaded from `vocab_file`.
+        merges (`str` or `list[str]`, *optional*):
+            Custom merges list. If not provided, merges are loaded from `vocab_file`.
     """

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = BPE

     def __init__(
         self,
-
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
         eos_token: str = "</s>",
         unk_token: str = "<unk>",
         additional_special_tokens: Optional[list] = None,
-        vocab: Optional[dict] = None,
-        merges: Optional[list] = None,
         **kwargs,
     ):
-        self.vocab_file = vocab_file
-
-        if vocab is not None:
-            self._vocab = vocab
-        else:
-            self._vocab = {}
-
-        if merges is not None:
-            # Convert lists to tuples if necessary (happens when loading from JSON)
-            self._merges = [tuple(merge) if isinstance(merge, list) else merge for merge in merges]
-        else:
-            self._merges = []
+        self._vocab = vocab or {}
+        self._merges = merges or []

         self._tokenizer = Tokenizer(
             BPE(
@@ -106,10 +95,7 @@ class ReformerTokenizer(TokenizersBackend):

         self._tokenizer.normalizer = normalizers.Sequence(
             [
-                normalizers.Replace("\n", " "),
-                normalizers.Replace("\r", " "),
-                normalizers.Replace("\t", " "),
-                normalizers.Replace(Regex(r" {2,}"), " "),
+                normalizers.Replace(Regex(r"\s{2,}|[\n\r\t]"), " "),
                 normalizers.NFC(),
                 normalizers.Strip(left=False, right=True),
             ]
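The four chained `Replace` normalizers collapse into one: `\s{2,}|[\n\r\t]` rewrites any run of two or more whitespace characters, or a lone newline, carriage return, or tab, into a single space. Python's `re` reproduces the behaviour for a quick check (the `tokenizers` library uses its own regex engine, but the semantics agree for this pattern):

```python
import re

pattern = re.compile(r"\s{2,}|[\n\r\t]")

text = "hello\t  world\nfoo\r\nbar"
print(pattern.sub(" ", text))  # "hello world foo bar"
```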
@@ -118,10 +104,7 @@ class ReformerTokenizer(TokenizersBackend):
         self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always")
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always")

-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             eos_token=eos_token,
             unk_token=unk_token,
             additional_special_tokens=additional_special_tokens or [],
@@ -278,6 +278,10 @@ class RegNetPreTrainedModel(PreTrainedModel):
         elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
             init.constant_(module.weight, 1)
             init.constant_(module.bias, 0)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)


 @auto_docstring
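Weight init now also resets BatchNorm's running statistics, which are buffers rather than parameters and were previously left at whatever prior forward passes put there; the `getattr` guard skips `nn.GroupNorm`, which has no running stats. The same reset in plain PyTorch:

```python
import torch
from torch import nn

bn = nn.BatchNorm2d(8)
bn(torch.randn(4, 8, 2, 2))    # a forward pass in train mode updates running stats
print(bn.num_batches_tracked)  # tensor(1)

with torch.no_grad():          # reset the buffers, as the new init does
    bn.running_mean.zero_()
    bn.running_var.fill_(1.0)
    bn.num_batches_tracked.zero_()
```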
@@ -294,7 +298,11 @@ class RegNetModel(RegNetPreTrainedModel):

     @auto_docstring
     def forward(
-        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+        self,
+        pixel_values: Tensor,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BaseModelOutputWithPoolingAndNoAttention:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -348,6 +356,7 @@ class RegNetForImageClassification(RegNetPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> ImageClassifierOutputWithNoAttention:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -21,6 +21,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -488,6 +489,11 @@ class RemBertPreTrainedModel(PreTrainedModel):
     base_model_prefix = "rembert"
     supports_gradient_checkpointing = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, RemBertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+

 @auto_docstring(
     custom_intro="""
@@ -540,6 +546,7 @@ class RemBertModel(RemBertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -659,6 +666,7 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -700,7 +708,7 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
             attentions=outputs.attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, is_first_iteration=False, **model_kwargs):
         input_shape = input_ids.shape
         effective_batch_size = input_shape[0]

@@ -857,6 +865,7 @@ class RemBertForSequenceClassification(RemBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -940,6 +949,7 @@ class RemBertForMultipleChoice(RemBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MultipleChoiceModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1043,6 +1053,7 @@ class RemBertForTokenClassification(RemBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1109,6 +1120,7 @@ class RemBertForQuestionAnswering(RemBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

@@ -14,7 +14,7 @@
 # limitations under the License.
 """Tokenization classes for RemBert model."""

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import Unigram
@@ -74,11 +74,11 @@ class RemBertTokenizer(TokenizersBackend):

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = Unigram

     def __init__(
         self,
-
+        vocab: Optional[Union[str, list[tuple[str, float]]]] = None,
         do_lower_case: bool = False,
         keep_accents: bool = False,
         bos_token: str = "[CLS]",
@@ -90,11 +90,8 @@ class RemBertTokenizer(TokenizersBackend):
         mask_token: str = "[MASK]",
         add_prefix_space: bool = True,
         remove_space: bool = True,
-        vocab: Optional[dict] = None,
-        merges: Optional[list] = None,
         **kwargs,
     ):
-        self.vocab_file = vocab_file
         self.remove_space = remove_space
         self.do_lower_case = do_lower_case
         self.keep_accents = keep_accents
@@ -147,11 +144,7 @@ class RemBertTokenizer(TokenizersBackend):
         self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme=prepend_scheme)

         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme=prepend_scheme)
-
-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             add_prefix_space=add_prefix_space,
             do_lower_case=do_lower_case,
             keep_accents=keep_accents,
@@ -262,9 +262,14 @@ class ResNetPreTrainedModel(PreTrainedModel):
             fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
             init.uniform_(module.bias, -bound, bound)
-
-
-        init.
+        # We need to check it like that as some Detr models replace the BatchNorm2d by their own
+        elif "BatchNorm" in module.__class__.__name__:
+            init.ones_(module.weight)
+            init.zeros_(module.bias)
+            init.zeros_(module.running_mean)
+            init.ones_(module.running_var)
+            if getattr(module, "num_batches_tracked", None) is not None:
+                init.zeros_(module.num_batches_tracked)


 @auto_docstring
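The string-based `"BatchNorm" in module.__class__.__name__` test is deliberate: as the comment says, DETR-family models substitute a custom frozen variant for `nn.BatchNorm2d`, which an `isinstance` check would miss. A toy illustration (the stand-in class below is hypothetical):

```python
from torch import nn


class FrozenBatchNorm2d(nn.Module):
    """Stand-in for a Detr-style replacement; note it is not an nn.BatchNorm2d subclass."""


module = FrozenBatchNorm2d()
print(isinstance(module, nn.BatchNorm2d))    # False: isinstance misses it
print("BatchNorm" in type(module).__name__)  # True: the name-based check catches it
```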
@@ -280,7 +285,11 @@ class ResNetModel(ResNetPreTrainedModel):

     @auto_docstring
     def forward(
-        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+        self,
+        pixel_values: Tensor,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BaseModelOutputWithPoolingAndNoAttention:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -333,6 +342,7 @@ class ResNetForImageClassification(ResNetPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> ImageClassifierOutputWithNoAttention:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -380,7 +390,11 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):

     @auto_docstring
     def forward(
-        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+        self,
+        pixel_values: Tensor,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BackboneOutput:
         r"""
         Examples:
@@ -501,6 +501,9 @@ class RobertaPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, RobertaLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, RobertaEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 class RobertaEncoder(nn.Module):
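This hunk (mirrored for RemBert, the modular RoBERTa file, RoBERTa-PreLayerNorm and RoCBert elsewhere in the diff) moves the `position_ids` and `token_type_ids` buffer values into `_init_weights`, so they go through the `init.*` helpers like every other weight. What the buffers end up containing, in isolation (`max_len` stands in for `max_position_embeddings`):

```python
import torch

max_len = 6  # stand-in for config.max_position_embeddings
position_ids = torch.arange(max_len).expand((1, -1))  # tensor([[0, 1, 2, 3, 4, 5]])
token_type_ids = torch.zeros_like(position_ids)       # tensor([[0, 0, 0, 0, 0, 0]])
```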
@@ -172,6 +172,9 @@ class RobertaPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, RobertaLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, RobertaEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 class RobertaModel(BertModel):
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Tokenization classes for RoBERTa."""

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
 from tokenizers.models import BPE
@@ -59,6 +59,10 @@ class RobertaTokenizer(TokenizersBackend):
     this superclass for more information regarding those methods.

     Args:
+        vocab (`str`, `dict` or `list`, *optional*):
+            Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
+        merges (`str` or `list`, *optional*):
+            Custom merges list. If not provided, merges are loaded from merges_file.
         errors (`str`, *optional*, defaults to `"replace"`):
             Paradigm to follow when decoding bytes to UTF-8. See
             [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
@@ -102,18 +106,16 @@ class RobertaTokenizer(TokenizersBackend):
             other word. (RoBERTa tokenizer detects beginning of words by the preceding space).
         trim_offsets (`bool`, *optional*, defaults to `True`):
             Whether the post processing step should trim offsets to avoid including whitespaces.
-        vocab (`dict`, *optional*):
-            Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
-        merges (`list`, *optional*):
-            Custom merges list. If not provided, merges are loaded from merges_file.
     """

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = BPE

     def __init__(
         self,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
         errors: str = "replace",
         bos_token: str = "<s>",
         eos_token: str = "</s>",
@@ -124,30 +126,22 @@ class RobertaTokenizer(TokenizersBackend):
         mask_token: str = "<mask>",
         add_prefix_space: bool = False,
         trim_offsets: bool = True,
-        vocab: Optional[dict] = None,
-        merges: Optional[list] = None,
         **kwargs,
     ):
         self.add_prefix_space = add_prefix_space
         self.trim_offsets = trim_offsets

-        if vocab is not None:
-            self._vocab = (
-                {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
-            )
-        else:
-            self._vocab = {
+        if vocab is None:
+            vocab = {
                 str(pad_token): 0,
                 str(unk_token): 1,
                 str(cls_token): 2,
                 str(sep_token): 3,
                 str(mask_token): 4,
             }
+        self._vocab = vocab

-        if merges is not None:
-            self._merges = [tuple(merge) if isinstance(merge, list) else merge for merge in merges]
-        else:
-            self._merges = []
+        self._merges = merges or []

         self._tokenizer = Tokenizer(
             BPE(
@@ -162,17 +156,8 @@ class RobertaTokenizer(TokenizersBackend):

         self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
         self._tokenizer.decoder = decoders.ByteLevel()
-        self._tokenizer.post_processor = processors.RobertaProcessing(
-            sep=(str(sep_token), self._vocab.get(str(sep_token), 3)),
-            cls=(str(cls_token), self._vocab.get(str(cls_token), 2)),
-            add_prefix_space=add_prefix_space,
-            trim_offsets=trim_offsets,
-        )
-
-        tokenizer_object = self._tokenizer

         super().__init__(
-            tokenizer_object=tokenizer_object,
             errors=errors,
             bos_token=bos_token,
             eos_token=eos_token,
@@ -185,6 +170,12 @@ class RobertaTokenizer(TokenizersBackend):
             trim_offsets=trim_offsets,
             **kwargs,
         )
+        self._tokenizer.post_processor = processors.RobertaProcessing(
+            sep=(str(sep_token), self.sep_token_id),
+            cls=(str(cls_token), self.cls_token_id),
+            add_prefix_space=add_prefix_space,
+            trim_offsets=trim_offsets,
+        )


 __all__ = ["RobertaTokenizer"]
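Setting the post-processor after `super().__init__()` lets it read the resolved `self.sep_token_id` and `self.cls_token_id` instead of the old hard-coded fallbacks (`self._vocab.get(str(sep_token), 3)`), which silently assumed RoBERTa's default ids whenever a custom vocabulary lacked the key:

```python
# Hypothetical custom vocab where the special tokens are not at ids 0-4.
vocab = {"<s>": 10, "</s>": 11, "<unk>": 0, "<pad>": 1, "<mask>": 2}

# Old behaviour: a missing key silently fell back to the RoBERTa default id.
print(vocab.get("</s>", 3))  # 11 here, but 3 whenever the key is absent
```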
@@ -561,6 +561,9 @@ class RobertaPreLayerNormPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, RobertaPreLayerNormLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, RobertaPreLayerNormEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 @auto_docstring(
@@ -621,6 +621,9 @@ class RoCBertPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, RoCBertLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, RoCBertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 @auto_docstring(