transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/__init__.py
CHANGED
```diff
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
     from .albert import *
     from .align import *
     from .altclip import *
+    from .apertus import *
     from .arcee import *
     from .aria import *
     from .audio_spectrogram_transformer import *
@@ -107,6 +108,7 @@ if TYPE_CHECKING:
     from .dinov3_vit import *
     from .distilbert import *
     from .dit import *
+    from .doge import *
     from .donut import *
     from .dots1 import *
     from .dpr import *
@@ -119,13 +121,18 @@ if TYPE_CHECKING:
     from .emu3 import *
     from .encodec import *
     from .encoder_decoder import *
+    from .eomt import *
     from .ernie import *
+    from .ernie4_5 import *
+    from .ernie4_5_moe import *
+    from .ernie4_5_vl_moe import *
     from .esm import *
     from .evolla import *
     from .exaone4 import *
     from .falcon import *
     from .falcon_h1 import *
     from .falcon_mamba import *
+    from .fast_vlm import *
     from .fastspeech2_conformer import *
     from .flaubert import *
     from .flava import *
@@ -143,9 +150,11 @@ if TYPE_CHECKING:
     from .git import *
     from .glm import *
     from .glm4 import *
+    from .glm4_moe import *
     from .glm4v import *
     from .glm4v_moe import *
     from .glm46v import *
+    from .glmasr import *
     from .glpn import *
     from .got_ocr2 import *
     from .gpt2 import *
@@ -180,11 +189,14 @@ if TYPE_CHECKING:
     from .instructblip import *
     from .instructblipvideo import *
     from .internvl import *
+    from .jais2 import *
     from .jamba import *
     from .janus import *
     from .jetmoe import *
     from .kosmos2 import *
+    from .kosmos2_5 import *
     from .kyutai_speech_to_text import *
+    from .lasr import *
     from .layoutlm import *
     from .layoutlmv2 import *
     from .layoutlmv3 import *
@@ -218,6 +230,7 @@ if TYPE_CHECKING:
     from .mbart50 import *
     from .megatron_bert import *
     from .megatron_gpt2 import *
+    from .metaclip_2 import *
     from .mgp_str import *
     from .mimi import *
     from .minimax import *
@@ -229,6 +242,7 @@ if TYPE_CHECKING:
     from .mlcd import *
     from .mllama import *
     from .mluke import *
+    from .mm_grounding_dino import *
     from .mobilebert import *
     from .mobilenet_v1 import *
     from .mobilenet_v2 import *
@@ -263,10 +277,14 @@ if TYPE_CHECKING:
     from .ovis2 import *
     from .owlv2 import *
     from .owlvit import *
+    from .paddleocr_vl import *
     from .paligemma import *
     from .parakeet import *
     from .patchtsmixer import *
     from .patchtst import *
+    from .pe_audio import *
+    from .pe_audio_video import *
+    from .pe_video import *
     from .pegasus import *
     from .pegasus_x import *
     from .perceiver import *
@@ -278,6 +296,7 @@ if TYPE_CHECKING:
     from .phimoe import *
     from .phobert import *
     from .pix2struct import *
+    from .pixio import *
     from .pixtral import *
     from .plbart import *
     from .poolformer import *
@@ -314,8 +333,10 @@ if TYPE_CHECKING:
     from .sam import *
     from .sam2 import *
     from .sam2_video import *
+    from .sam3 import *
     from .sam3_tracker import *
     from .sam3_tracker_video import *
+    from .sam3_video import *
     from .sam_hq import *
     from .seamless_m4t import *
     from .seamless_m4t_v2 import *
@@ -327,6 +348,7 @@ if TYPE_CHECKING:
     from .shieldgemma2 import *
     from .siglip import *
     from .siglip2 import *
+    from .smollm3 import *
     from .smolvlm import *
     from .speech_encoder_decoder import *
     from .speech_to_text import *
```
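These hunks all land in the `if TYPE_CHECKING:` branch, which exists for type checkers and IDEs; at runtime, recent transformers releases replace the module with a lazy loader so importing `transformers.models` stays cheap. A rough sketch of that convention (the helper names `_LazyModule` and `define_import_structure` are taken from recent releases and may differ slightly in rc2):

```python
# Sketch of the lazy-import convention behind transformers/models/__init__.py.
import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static tooling sees real star-imports, one per model package:
    from .albert import *
    # ... (the lines added in the hunks above extend this list)
else:
    from ..utils import _LazyModule
    from ..utils.import_utils import define_import_structure

    # At runtime, submodules are resolved lazily on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
```

So a one-line `from .<model> import *` addition is what exposes a new package such as `ernie4_5_vl_moe`, `jais2`, or `pe_audio` to static analysis; the lazy loader discovers it from the directory layout itself.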
transformers/models/afmoe/modeling_afmoe.py
CHANGED

```diff
@@ -25,11 +25,11 @@ from typing import Optional, Union
 import torch
 from torch import nn
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_func_from_hub
-from ...integrations.hub_kernels import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MoeModelOutputWithPast
@@ -37,7 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_afmoe import AfmoeConfig
 
 
```
```diff
@@ -58,7 +58,7 @@ class AfmoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
```
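The `original_inv_freq` change swaps a plain tensor attribute for a non-persistent buffer. Buffers follow the module through `.to()` / `.cuda()` / dtype casts, while plain attributes are left behind, and `persistent=False` keeps the tensor out of `state_dict`, so existing checkpoints are unaffected. A minimal standalone illustration (class and attribute names here are illustrative only):

```python
import torch
from torch import nn

class RotaryDemo(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.arange(4.0)
        self.plain_copy = inv_freq.clone()  # plain attribute: ignored by .to()
        # Non-persistent buffer: moves/casts with the module, but is still
        # excluded from state_dict, so checkpoint layouts don't change.
        self.register_buffer("buffered_copy", inv_freq.clone(), persistent=False)

m = RotaryDemo().to(torch.float64)
print("buffered_copy" in m.state_dict())          # False
print(m.plain_copy.dtype, m.buffered_copy.dtype)  # torch.float32 torch.float64
```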
```diff
@@ -97,7 +97,7 @@ class AfmoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
```
@@ -338,6 +338,7 @@ def eager_attention_forward(
|
|
|
338
338
|
return attn_output, attn_weights
|
|
339
339
|
|
|
340
340
|
|
|
341
|
+
@use_kernelized_func(apply_rotary_pos_emb)
|
|
341
342
|
class AfmoeAttention(nn.Module):
|
|
342
343
|
"""
|
|
343
344
|
Multi-headed attention module with optional sliding window and gating.
|
|
@@ -369,7 +370,6 @@ class AfmoeAttention(nn.Module):
|
|
|
369
370
|
self.o_proj = nn.Linear(
|
|
370
371
|
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
|
371
372
|
)
|
|
372
|
-
self.rotary_fn = apply_rotary_pos_emb
|
|
373
373
|
# Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
|
|
374
374
|
# We only add AFMoE-specific attributes
|
|
375
375
|
self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
|
|
@@ -531,20 +531,11 @@ class AfmoePreTrainedModel(PreTrainedModel):
|
|
|
531
531
|
|
|
532
532
|
def _init_weights(self, module):
|
|
533
533
|
"""Initialize the weights"""
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
nn.init.zeros_(module.bias)
|
|
538
|
-
elif isinstance(module, nn.Embedding):
|
|
539
|
-
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
|
|
540
|
-
if module.padding_idx is not None:
|
|
541
|
-
nn.init.zeros_(module.weight[module.padding_idx])
|
|
542
|
-
elif isinstance(module, AfmoeRMSNorm):
|
|
543
|
-
nn.init.ones_(module.weight)
|
|
544
|
-
elif isinstance(module, AfmoeTokenChoiceRouter):
|
|
545
|
-
nn.init.zeros_(module.gate.weight)
|
|
534
|
+
super()._init_weights(module)
|
|
535
|
+
if isinstance(module, AfmoeTokenChoiceRouter):
|
|
536
|
+
init.zeros_(module.gate.weight)
|
|
546
537
|
elif isinstance(module, AfmoeMoE):
|
|
547
|
-
|
|
538
|
+
init.zeros_(module.expert_bias)
|
|
548
539
|
|
|
549
540
|
|
|
550
541
|
@auto_docstring
|
|
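The `original_inv_freq` change stores the pristine RoPE frequencies as a non-persistent buffer instead of a bare attribute, so they follow the module across devices while staying out of checkpoints (useful when dynamic RoPE scaling overwrites `inv_freq` and the original values must be restored later). A minimal sketch of the buffer semantics, independent of transformers:

    import torch
    from torch import nn

    class RotarySketch(nn.Module):
        def __init__(self, inv_freq: torch.Tensor):
            super().__init__()
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            # a buffer (unlike a plain attribute) moves with .to()/.cuda() together
            # with the module, and persistent=False keeps it out of state_dict()
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    rot = RotarySketch(torch.arange(4.0))
    assert "original_inv_freq" not in rot.state_dict()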
--- a/transformers/models/afmoe/modular_afmoe.py
+++ b/transformers/models/afmoe/modular_afmoe.py
@@ -20,6 +20,7 @@ from typing import Optional
 import torch
 from torch import nn

+from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
@@ -350,20 +351,11 @@ class AfmoePreTrainedModel(PreTrainedModel):

     def _init_weights(self, module):
         """Initialize the weights"""
-        if isinstance(module, nn.Linear):
-            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Embedding):
-            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                nn.init.zeros_(module.weight[module.padding_idx])
-        elif isinstance(module, AfmoeRMSNorm):
-            nn.init.ones_(module.weight)
-        elif isinstance(module, AfmoeTokenChoiceRouter):
-            nn.init.zeros_(module.gate.weight)
+        super()._init_weights(module)
+        if isinstance(module, AfmoeTokenChoiceRouter):
+            init.zeros_(module.gate.weight)
         elif isinstance(module, AfmoeMoE):
-            nn.init.zeros_(module.expert_bias)
+            init.zeros_(module.expert_bias)


 @auto_docstring
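Both afmoe files switch `_init_weights` from re-implementing the generic Linear/Embedding/RMSNorm branches to delegating them to `super()._init_weights(module)` and keeping only model-specific parameters. A minimal sketch of the delegation pattern (simplified; not the actual transformers base class):

    import torch
    from torch import nn

    class BaseSketch(nn.Module):
        initializer_range = 0.02

        def _init_weights(self, module):
            # generic branches live once in the base class
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    class MoESketch(BaseSketch):
        def _init_weights(self, module):
            super()._init_weights(module)          # generic cases first
            if hasattr(module, "expert_bias"):     # then model-specific parameters
                nn.init.zeros_(module.expert_bias)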
--- a/transformers/models/aimv2/modeling_aimv2.py
+++ b/transformers/models/aimv2/modeling_aimv2.py
@@ -414,6 +414,10 @@ class Aimv2PreTrainedModel(PreTrainedModel):
             init.constant_(module.logit_scale, math.log(1 / 0.07))
         elif isinstance(module, Aimv2AttentionPoolingHead):
             init.normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, Aimv2VisionEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+        elif isinstance(module, Aimv2TextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 @auto_docstring(
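The new branches initialize the `position_ids` buffer explicitly instead of relying on its construction-time value; `init.copy_` writes `torch.arange(...)` into the existing buffer. The equivalent plain-tensor operation:

    import torch

    # what init.copy_(module.position_ids, torch.arange(...).expand((1, -1))) produces:
    position_ids = torch.empty(1, 6, dtype=torch.long)
    position_ids.copy_(torch.arange(position_ids.shape[-1]).expand((1, -1)))
    print(position_ids)  # tensor([[0, 1, 2, 3, 4, 5]])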
--- a/transformers/models/aimv2/modular_aimv2.py
+++ b/transformers/models/aimv2/modular_aimv2.py
@@ -457,6 +457,10 @@ class Aimv2PreTrainedModel(PreTrainedModel):
             init.constant_(module.logit_scale, math.log(1 / 0.07))
         elif isinstance(module, Aimv2AttentionPoolingHead):
             init.normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, Aimv2VisionEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+        elif isinstance(module, Aimv2TextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 @auto_docstring(
--- a/transformers/models/albert/modeling_albert.py
+++ b/transformers/models/albert/modeling_albert.py
@@ -320,6 +320,9 @@ class AlbertPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, AlbertMLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, AlbertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 @dataclass
--- a/transformers/models/albert/tokenization_albert.py
+++ b/transformers/models/albert/tokenization_albert.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Tokenization classes for ALBERT model."""

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import Unigram
@@ -73,8 +73,8 @@ class AlbertTokenizer(TokenizersBackend):
             other word.
         trim_offsets (`bool`, *optional*, defaults to `True`):
             Whether the post processing step should trim offsets to avoid including whitespaces.
-        vocab (`
-            Custom vocabulary
+        vocab (`str` or `list[tuple[str, float]]`, *optional*):
+            Custom vocabulary with `(token, score)` tuples. If not provided, vocabulary is loaded from `vocab_file`.
         vocab_file (`str`, *optional*):
             [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
             contains the vocabulary necessary to instantiate a tokenizer.
@@ -82,10 +82,11 @@ class AlbertTokenizer(TokenizersBackend):

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = Unigram

     def __init__(
         self,
+        vocab: Optional[Union[str, list[tuple[str, float]]]] = None,
         do_lower_case: bool = True,
         keep_accents: bool = False,
         bos_token: str = "[CLS]",
@@ -97,19 +98,15 @@ class AlbertTokenizer(TokenizersBackend):
         mask_token: str = "[MASK]",
         add_prefix_space: bool = True,
         trim_offsets: bool = True,
-        vocab: Optional[dict] = None,
-        vocab_file: Optional[str] = None,
         **kwargs,
     ):
-        self.vocab_file = vocab_file
         self.add_prefix_space = add_prefix_space
         self.trim_offsets = trim_offsets
-
         self.do_lower_case = do_lower_case
         self.keep_accents = keep_accents

         if vocab is not None:
-            self._vocab_scores =
+            self._vocab_scores = vocab
         else:
             self._vocab_scores = [
                 (str(pad_token), 0.0),
@@ -163,10 +160,7 @@ class AlbertTokenizer(TokenizersBackend):
             ],
         )

-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             do_lower_case=self.do_lower_case,
             keep_accents=self.keep_accents,
             bos_token=bos_token,
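Per this hunk, the constructor now takes `vocab` as its first parameter, typed as a path or a list of `(token, score)` pairs for the Unigram model, while `vocab_file` drops out of the explicit signature. A sketch based only on the signature above (not tested against rc2; behavior may differ):

    from transformers.models.albert.tokenization_albert import AlbertTokenizer

    custom_vocab = [
        ("[PAD]", 0.0), ("[CLS]", 0.0), ("[SEP]", 0.0), ("[UNK]", 0.0), ("[MASK]", 0.0),
        ("▁hello", -1.5), ("▁world", -2.0),  # "▁" marks SentencePiece word starts
    ]
    tokenizer = AlbertTokenizer(vocab=custom_vocab)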
--- a/transformers/models/align/modeling_align.py
+++ b/transformers/models/align/modeling_align.py
@@ -781,9 +781,9 @@ class AlignTextEncoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)

             layer_outputs = layer_module(
-                hidden_states
-                attention_mask
-                output_attentions
+                hidden_states,
+                attention_mask,
+                output_attentions,
                 **kwargs,
             )

@@ -844,6 +844,13 @@ class AlignPreTrainedModel(PreTrainedModel):
         if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
+        elif isinstance(module, AlignTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 @auto_docstring(
@@ -976,6 +983,8 @@ class AlignVisionModel(AlignPreTrainedModel):
     main_input_name = "pixel_values"
     input_modalities = ("image",)
     supports_gradient_checkpointing = False
+    _input_embed_layer = "convolution"
+    _no_split_modules = ["AlignVisionBlock"]

     def __init__(self, config: AlignVisionConfig):
         super().__init__(config)
@@ -994,9 +1003,6 @@ class AlignVisionModel(AlignPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    def get_input_embeddings(self) -> nn.Module:
-        return self.vision_model.embeddings.convolution
-
     @can_return_tuple
     @auto_docstring
     def forward(
@@ -1004,6 +1010,7 @@ class AlignVisionModel(AlignPreTrainedModel):
         pixel_values: Optional[torch.FloatTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
         r"""
         Examples:
@@ -1169,6 +1176,7 @@ class AlignModel(AlignPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, AlignOutput]:
         r"""
         return_loss (`bool`, *optional*):
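The added branch also resets BatchNorm's running statistics, not just its affine parameters; `running_mean`, `running_var`, and `num_batches_tracked` are buffers, and the `getattr` guard skips LayerNorm, which has none. A plain-PyTorch equivalent of the three added lines:

    import torch
    from torch import nn

    bn = nn.BatchNorm2d(8)
    if getattr(bn, "running_mean", None) is not None:
        nn.init.zeros_(bn.running_mean)   # running mean back to 0
        nn.init.ones_(bn.running_var)     # running variance back to 1
        bn.num_batches_tracked.zero_()    # batch counter back to 0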
--- a/transformers/models/altclip/modeling_altclip.py
+++ b/transformers/models/altclip/modeling_altclip.py
@@ -393,9 +393,9 @@ class AltRobertaEncoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)

             layer_outputs = layer_module(
-                hidden_states
-                attention_mask
-                output_attentions
+                hidden_states,
+                attention_mask,
+                output_attentions,
                 **kwargs,
             )

@@ -780,6 +780,7 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
             init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.num_positions).expand((1, -1)))
         elif isinstance(module, AltCLIPAttention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
@@ -815,6 +816,9 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
             # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
             if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                 init.zeros_(module.weight[module.padding_idx])
+        elif isinstance(module, AltRobertaEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 class AltCLIPVisionTransformer(nn.Module):
@@ -891,6 +895,7 @@ class AltCLIPVisionModel(AltCLIPPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         Examples:
@@ -970,6 +975,7 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1061,6 +1067,7 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndProjection]:
         r"""
         Examples:
@@ -1236,6 +1243,7 @@ class AltCLIPModel(AltCLIPPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, AltCLIPOutput]:
         r"""
         return_loss (`bool`, *optional*):
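Several `forward` signatures in these hunks gain a trailing `**kwargs`, so shared calling code can pass newer arguments without older models raising `TypeError`. The mechanism in isolation:

    import torch
    from torch import nn

    class ForwardSketch(nn.Module):
        def forward(self, pixel_values: torch.Tensor, **kwargs):
            # unexpected keyword arguments are absorbed instead of failing
            return pixel_values

    ForwardSketch()(torch.zeros(2, 3), output_attentions=False, some_newer_flag=True)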
--- a/transformers/models/apertus/modeling_apertus.py
+++ b/transformers/models/apertus/modeling_apertus.py
@@ -25,10 +25,10 @@ from typing import Optional, Union
 import torch
 from torch import nn

-from ...activations import ACT2FN
+from ...activations import ACT2CLS, ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GenericForTokenClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -36,7 +36,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_apertus import ApertusConfig


@@ -49,6 +49,8 @@ class ApertusMLP(nn.Module):
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
+        if config.hidden_act == "xielu":
+            self.act_fn = ACT2CLS["xielu"](dtype=config.dtype)

     def forward(self, x):
         return self.down_proj(self.act_fn(self.up_proj(x)))
@@ -92,7 +94,7 @@ class ApertusRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -131,7 +133,7 @@ class ApertusRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -213,6 +215,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class ApertusAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -238,7 +241,6 @@ class ApertusAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
         self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)

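`maybe_autocast` replaces `torch.autocast` in the RoPE forward across these models. Presumably it degrades gracefully where autocast is unavailable; a plausible shim under that assumption (not the actual transformers implementation):

    import contextlib
    import torch

    def maybe_autocast_sketch(device_type: str, enabled: bool = True):
        # assumption: fall back to a null context on devices without autocast support
        try:
            return torch.autocast(device_type=device_type, enabled=enabled)
        except RuntimeError:
            return contextlib.nullcontext()

    with maybe_autocast_sketch(device_type="cpu", enabled=False):  # force float32 math
        freqs = torch.arange(4.0)[:, None] @ torch.arange(3.0)[None, :]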
--- a/transformers/models/apertus/modular_apertus.py
+++ b/transformers/models/apertus/modular_apertus.py
@@ -19,6 +19,7 @@ from typing import Optional
 import torch
 from torch import nn

+from ...activations import ACT2CLS
 from ...cache_utils import Cache
 from ...configuration_utils import PreTrainedConfig
 from ...modeling_rope_utils import RopeParameters
@@ -192,9 +193,11 @@ class ApertusConfig(PreTrainedConfig):

 class ApertusMLP(NemotronMLP):
     def __init__(self, config):
-        super().__init__()
+        super().__init__(config)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        if config.hidden_act == "xielu":
+            self.act_fn = ACT2CLS["xielu"](dtype=config.dtype)


 class ApertusRMSNorm(LlamaRMSNorm):
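The `xielu` activation needs a constructor argument (`dtype`), which a plain `ACT2FN` string lookup cannot supply, hence the `ACT2CLS` class lookup for that one case. The distinction, sketched with stand-in mappings:

    from torch import nn

    ACT2CLS_SKETCH = {"gelu": nn.GELU}                               # name -> activation class
    ACT2FN_SKETCH = {k: cls() for k, cls in ACT2CLS_SKETCH.items()}  # name -> ready instance

    act_default = ACT2FN_SKETCH["gelu"]        # instance lookup: no ctor args possible
    act_with_args = ACT2CLS_SKETCH["gelu"]()   # class lookup: ctor args (e.g. dtype) fit here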
--- a/transformers/models/arcee/modeling_arcee.py
+++ b/transformers/models/arcee/modeling_arcee.py
@@ -30,7 +30,7 @@ from transformers.utils import auto_docstring
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import (
     GenericForQuestionAnswering,
@@ -43,7 +43,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_arcee import ArceeConfig


@@ -99,7 +99,7 @@ class ArceeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -138,7 +138,7 @@ class ArceeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -220,6 +220,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class ArceeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -245,7 +246,6 @@ class ArceeAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
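Across these models, the per-instance assignment `self.rotary_fn = apply_rotary_pos_emb` becomes the class decorator `@use_kernelized_func(apply_rotary_pos_emb)`, centralizing where the (possibly hub-accelerated) kernel is attached. A guess at the shape of such a decorator (the real one lives in transformers.integrations and may do more, e.g. hub kernel dispatch):

    def use_kernelized_func_sketch(func):
        # assumption: attach the callable once at class level so every instance
        # shares one swappable implementation instead of re-binding it in __init__
        def decorator(cls):
            cls.rotary_fn = staticmethod(func)
            return cls
        return decorator

    def apply_rotary_sketch(q, k, cos, sin):
        return q * cos + k * sin  # placeholder math, illustration only

    @use_kernelized_func_sketch(apply_rotary_sketch)
    class AttentionSketch:
        pass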
--- a/transformers/models/aria/modeling_aria.py
+++ b/transformers/models/aria/modeling_aria.py
@@ -29,7 +29,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from ..auto import AutoModel
 from .configuration_aria import AriaConfig, AriaTextConfig

@@ -444,6 +444,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class AriaTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -469,7 +470,6 @@ class AriaTextAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
@@ -636,7 +636,7 @@ class AriaTextRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -675,7 +675,7 @@ class AriaTextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -1203,6 +1203,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         model_inputs = super().prepare_inputs_for_generation(
@@ -1212,12 +1213,15 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        if cache_position[0] == 0:
-            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-            # Otherwise we need pixel values to be passed to model
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration if available
+            # In subsquent iterations, they are already merged with text and cached
+            # NOTE: first iteration doesn't have to be prefill, it can be the first
+            # iteration with a question and cached system prompt (continue generate from cache)
             model_inputs["pixel_values"] = pixel_values
             model_inputs["pixel_mask"] = pixel_mask

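The Aria hunk replaces the `cache_position[0] == 0` prefill check with an explicit `is_first_iteration` flag: pixel values are forwarded only on the first generation step (or when caching is off), since afterwards the image features already live in the KV cache, and "first iteration" need not mean an empty cache when generation continues from a cached system prompt. The gating logic in isolation (a sketch, with hypothetical helper and argument names):

    def gate_vision_inputs(model_inputs, pixel_values, pixel_mask, is_first_iteration, use_cache=True):
        # forward multimodal tensors only while they can still influence the cache
        if is_first_iteration or not use_cache:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["pixel_mask"] = pixel_mask
        return model_inputs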