transformers 5.0.0rc0-py3-none-any.whl → 5.0.0rc2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/fast_vlm/modular_fast_vlm.py
@@ -0,0 +1,273 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...configuration_utils import PreTrainedConfig
+from ...utils import auto_docstring
+from ..auto import CONFIG_MAPPING
+from ..llava.configuration_llava import LlavaConfig
+from ..llava.modeling_llava import (
+    LlavaForConditionalGeneration,
+    LlavaModel,
+    LlavaMultiModalProjector,
+    LlavaPreTrainedModel,
+)
+
+
+class FastVlmConfig(LlavaConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`FastVlmForConditionalGeneration`]. It is used to instantiate a
+    FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield the same configuration as the one of FastVLM-7B.
+
+    e.g. [KamilaMila/FastVLM-7B](https://huggingface.co/KamilaMila/FastVLM-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `TimmWrapperConfig` for `fastvit_mci3`):
+            The config object or dictionary of the vision backbone.
+        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
+            The config object or dictionary of the text backbone.
+        image_token_id (`int`, *optional*, defaults to 151646):
+            The image token index to encode the image prompt.
+        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            The activation function used by the multimodal projector.
+        vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`):
+            The feature selection strategy used to select the vision feature from the vision backbone.
+            Only "full" supported.
+        vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1):
+            The index of the layer to select the vision feature. If multiple indices are provided,
+            the vision feature of the corresponding indices will be concatenated to form the
+            vision features. Only -1 supported.
+        multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
+            Whether to use bias in the multimodal projector.
+
+    Example:
+
+    ```python
+    >>> from transformers import FastVlmForConditionalGeneration, FastVlmConfig
+
+    >>> # Initializing a FastVLM-7B style configuration
+    >>> configuration = FastVlmConfig()
+
+    >>> # Initializing a model from the FastVLM-7B style configuration
+    >>> model = FastVlmForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "fast_vlm"
+
+    def __init__(
+        self,
+        vision_config=None,
+        text_config=None,
+        image_token_id=151646,
+        projector_hidden_act="gelu",
+        vision_feature_select_strategy="full",
+        vision_feature_layer=-1,
+        multimodal_projector_bias=True,
+        **kwargs,
+    ):
+        self.image_token_id = image_token_id
+        self.projector_hidden_act = projector_hidden_act
+
+        if vision_feature_select_strategy != "full":
+            raise ValueError(
+                f"Unexpected select feature strategy: {vision_feature_select_strategy}. Only 'full' is supported in FastVLM."
+            )
+
+        if vision_feature_layer != -1:
+            raise ValueError(
+                f"Unexpected vision feature layer: {vision_feature_layer}. Only -1 is supported in FastVLM."
+            )
+
+        self.vision_feature_select_strategy = vision_feature_select_strategy
+        self.vision_feature_layer = vision_feature_layer
+
+        if isinstance(vision_config, dict):
+            vision_config["model_type"] = vision_config.get("model_type", "timm_wrapper")
+            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+        elif vision_config is None:
+            vision_config = CONFIG_MAPPING["timm_wrapper"](
+                architecture="fastvit_mci3",
+                do_pooling=True,
+                global_pool="avg",
+                hidden_size=3072,
+                initializer_range=0.02,
+                model_args={"inference_mode": True},
+            )
+
+        self.vision_config = vision_config
+
+        if isinstance(text_config, dict):
+            text_config["model_type"] = text_config.get("model_type", "qwen2")
+            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+        elif text_config is None:
+            text_config = CONFIG_MAPPING["qwen2"](
+                hidden_size=3584,
+                vocab_size=152128,
+                intermediate_size=18944,
+                num_attention_heads=28,
+                num_key_value_heads=4,
+                num_hidden_layers=28,
+            )
+
+        self.text_config = text_config
+        self.multimodal_projector_bias = multimodal_projector_bias
+
+        PreTrainedConfig.__init__(self, **kwargs)
+
+
+class FastVlmMultiModalProjector(LlavaMultiModalProjector):
+    def __init__(self, config: FastVlmConfig):
+        nn.Module.__init__(self)
+        self.linear_1 = nn.Linear(
+            config.vision_config.hidden_size,
+            config.text_config.hidden_size,
+            bias=config.multimodal_projector_bias,
+        )
+        self.act = ACT2FN[config.projector_hidden_act]
+        self.linear_2 = nn.Linear(
+            config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
+        )
+
+
+class FastVlmPreTrainedModel(LlavaPreTrainedModel):
+    pass
+
+
+class FastVlmModel(LlavaModel):
+    _checkpoint_conversion_mapping = {}
+
+    def __init__(self, config: FastVlmConfig):
+        super().__init__(config)
+
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        vision_feature_layer: Optional[Union[int, list[int]]] = None,
+        vision_feature_select_strategy: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Obtains image last hidden states from the vision tower and apply multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                The tensors corresponding to the input images.
+            vision_feature_layer (`Union[int, list[int]]`, *optional*):
+                The index/indices of the layer to select the vision feature. Only -1 supported.
+            vision_feature_select_strategy (`str`, *optional*):
+                The feature selection strategy used to select the vision feature from the vision backbone.
+                Only "full" supported.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+        """
+        vision_feature_layer = (
+            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+        )
+        vision_feature_select_strategy = (
+            vision_feature_select_strategy
+            if vision_feature_select_strategy is not None
+            else self.config.vision_feature_select_strategy
+        )
+
+        kwargs = {k: v for k, v in kwargs.items() if v is not None}
+        image_outputs = self.vision_tower(pixel_values, **kwargs)
+
+        # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava
+        selected_image_feature = image_outputs.last_hidden_state
+        selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1)
+        image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = list(image_features)
+        return image_features
+
+    def forward(self, **super_kwargs):
+        r"""
+        vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*):
+            The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the
+            corresponding indices will be concatenated to form the vision features. Only -1 supported.
+        vision_feature_select_strategy (`str`, *optional*):
+            The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported.
+        """
+        super().forward(**super_kwargs)
+
+
+@auto_docstring(
+    custom_intro="""
+    The FastVlm model which consists of a vision backbone and a language model.
+    """
+)
+class FastVlmForConditionalGeneration(LlavaForConditionalGeneration):
|
|
222
|
+
_checkpoint_conversion_mapping = {}
|
|
223
|
+
|
|
224
|
+
def forward(self, **super_kwargs):
|
|
225
|
+
r"""
|
|
226
|
+
vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*):
|
|
227
|
+
The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the
|
|
228
|
+
corresponding indices will be concatenated to form the vision features. Only -1 supported.
|
|
229
|
+
vision_feature_select_strategy (`str`, *optional*):
|
|
230
|
+
The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported.
|
|
231
|
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
232
|
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
233
|
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
234
|
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
235
|
+
|
|
236
|
+
Example:
|
|
237
|
+
|
|
238
|
+
```python
|
|
239
|
+
>>> from PIL import Image
|
|
240
|
+
>>> import requests
|
|
241
|
+
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
|
242
|
+
>>> import torch
|
|
243
|
+
|
|
244
|
+
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
245
|
+
|
|
246
|
+
>>> model = AutoModelForImageTextToText.from_pretrained("KamilaMila/FastVLM-0.5B").to(device)
|
|
247
|
+
>>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B")
|
|
248
|
+
|
|
249
|
+
>>> conversation = [
|
|
250
|
+
{
|
|
251
|
+
"role": "user",
|
|
252
|
+
"content": [
|
|
253
|
+
{"type": "text", "text": "What are these?"},
|
|
254
|
+
{"type": "image"}
|
|
255
|
+
]
|
|
256
|
+
}
|
|
257
|
+
]
|
|
258
|
+
|
|
259
|
+
>>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
|
260
|
+
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
|
261
|
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
262
|
+
|
|
263
|
+
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
|
|
264
|
+
|
|
265
|
+
>>> # Generate
|
|
266
|
+
>>> generated_ids = model.generate(**inputs, max_new_tokens=15)
|
|
267
|
+
>>> print(processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
|
|
268
|
+
system\n You are a helpful assistant.\n user\n What are these?\n assistant\n The image depicts a traditional Chinese street...
|
|
269
|
+
```"""
|
|
270
|
+
super().forward(**super_kwargs)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
__all__ = ["FastVlmForConditionalGeneration", "FastVlmModel", "FastVlmPreTrainedModel", "FastVlmConfig"]
|
|
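The `get_image_features` override above is the main behavioral difference from Llava: the timm `fastvit_mci3` backbone returns a spatial feature map rather than a token sequence, so it is flattened and permuted before the projector. A minimal sketch of that reshaping with made-up shapes (the batch size and 24×24 grid are illustrative assumptions, not values from a checkpoint):

```python
import torch

# Hypothetical vision-tower output: (num_images, embed_dim, height, width),
# e.g. hidden_size=3072 as in the default timm_wrapper config above.
features = torch.randn(2, 3072, 24, 24)

# flatten(2) merges the spatial grid, permute moves embed_dim last:
# (2, 3072, 24, 24) -> (2, 3072, 576) -> (2, 576, 3072)
tokens = features.flatten(2).permute(0, 2, 1)
print(tokens.shape)  # torch.Size([2, 576, 3072])
```

Each image then contributes `height * width` visual tokens to the language model after projection.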
@@ -514,7 +514,7 @@ class FastSpeech2ConformerConvolutionModule(nn.Module):
 
         Args:
             hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
-            attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
+            attention_mask (`torch.Tensor` of shape `(batch, 1, time, time)`): Attention mask.
 
         Returns:
             `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
@@ -530,7 +530,10 @@ class FastSpeech2ConformerConvolutionModule(nn.Module):
 
         # Apply padding mask before convolution
         if attention_mask is not None:
-            all_masked_rows = torch.all(~attention_mask, dim=2)
+            if attention_mask.dtype == torch.bool:
+                all_masked_rows = torch.all(~attention_mask, dim=2)
+            else:
+                all_masked_rows = torch.all(~(attention_mask == 0.0), dim=2)
             hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
 
         # 1D Depthwise Conv
@@ -724,19 +727,20 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
         self.embed_dim = config.hidden_size
         self.input_scale = math.sqrt(self.embed_dim)
         self.dropout = nn.Dropout(p=module_config["positional_dropout_rate"])
-        self.pos_enc = None
         self.max_len = 5000
-        self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len))
+        self.register_buffer(
+            "pos_enc", self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len)), persistent=False
+        )
 
-    def extend_pos_enc(self, x):
+    def extend_pos_enc(self, x, pos_enc=None):
         """Reset the positional encodings."""
-        if self.pos_enc is not None:
+        if pos_enc is not None:
            # self.pos_enc contains both positive and negative parts
            # the length of self.pos_enc is 2 * input_len - 1
-            if self.pos_enc.size(1) >= x.size(1) * 2 - 1:
-                if self.pos_enc.dtype != x.dtype or self.pos_enc.device != x.device:
-                    self.pos_enc = self.pos_enc.to(dtype=x.dtype, device=x.device)
-                return
+            if pos_enc.size(1) >= x.size(1) * 2 - 1:
+                if pos_enc.dtype != x.dtype or pos_enc.device != x.device:
+                    pos_enc = pos_enc.to(dtype=x.dtype, device=x.device)
+                return pos_enc
         # Suppose `i` means to the position of query vector and `j` means the
         # position of key vector. We use position relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -757,7 +761,7 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
         pos_enc_positive = torch.flip(pos_enc_positive, [0]).unsqueeze(0)
         pos_enc_negative = pos_enc_negative[1:].unsqueeze(0)
         pos_enc = torch.cat([pos_enc_positive, pos_enc_negative], dim=1)
-        self.pos_enc = pos_enc.to(device=x.device, dtype=x.dtype)
+        return pos_enc.to(device=x.device, dtype=x.dtype)
 
     def forward(self, feature_representation):
         """
@@ -768,7 +772,7 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
         Returns:
             `torch.Tensor`: Encoded tensor (batch_size, time, `*`).
         """
-        self.extend_pos_enc(feature_representation)
+        self.pos_enc = self.extend_pos_enc(feature_representation, self.pos_enc)
         hidden_states = feature_representation * self.input_scale
         center_idx = self.pos_enc.size(1) // 2
         pos_emb = self.pos_enc[:, center_idx - hidden_states.size(1) + 1 : center_idx + hidden_states.size(1)]
@@ -1007,6 +1011,10 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel):
         elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight)
             # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
@@ -1015,6 +1023,8 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel):
         elif isinstance(module, FastSpeech2ConformerAttention):
             init.xavier_uniform_(module.pos_bias_u)
             init.xavier_uniform_(module.pos_bias_v)
+        elif isinstance(module, FastSpeech2ConformerRelPositionalEncoding):
+            init.copy_(module.pos_enc, module.extend_pos_enc(torch.tensor(0.0).expand(1, module.max_len)))
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, FastSpeech2ConformerEncoder):
@@ -1118,6 +1128,7 @@ class FastSpeech2ConformerModel(FastSpeech2ConformerPreTrainedModel):
         return_dict: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, FastSpeech2ConformerModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1406,6 +1417,12 @@ class FastSpeech2ConformerHifiGan(PreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, FastSpeech2ConformerHifiGan):
+            init.zeros_(module.mean)
+            init.ones_(module.scale)
+
     def apply_weight_norm(self):
         weight_norm = nn.utils.weight_norm
         if hasattr(nn.utils.parametrizations, "weight_norm"):
@@ -1433,7 +1450,7 @@ class FastSpeech2ConformerHifiGan(PreTrainedModel):
         waveform.
         """
     )
-    def forward(self, spectrogram: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, spectrogram: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         r"""
         spectrogram (`torch.FloatTensor`):
             Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length,
@@ -1509,6 +1526,7 @@ class FastSpeech2ConformerWithHifiGan(PreTrainedModel):
         return_dict: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, FastSpeech2ConformerModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
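The new dtype branch in the convolution module accepts either boolean masks (`True` = attend) or additive float masks (`0.0` = attend). A small stand-alone check of the two expressions used above, with toy shapes chosen only for illustration:

```python
import torch

bool_mask = torch.tensor([[[True, False], [False, False]]])   # True = attend
float_mask = torch.tensor([[[0.0, -1e9], [-1e9, -1e9]]])      # 0.0 = attend

# Rows whose keys are all masked are zeroed out before the depthwise conv.
print(torch.all(~bool_mask, dim=2))            # tensor([[False,  True]])
print(torch.all(~(float_mask == 0.0), dim=2))  # tensor([[False,  True]])
```

Both branches flag the same fully-masked rows, which is what keeps the padding behavior unchanged across mask dtypes.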
@@ -660,9 +660,6 @@ class FlaubertPreTrainedModel(PreTrainedModel):
     config: FlaubertConfig
     base_model_prefix = "transformer"
 
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
     @property
     def dummy_inputs(self):
         inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
@@ -690,15 +687,17 @@ class FlaubertPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
-        if isinstance(module, FlaubertModel)
-
-
-
-
-
-
-
-
+        if isinstance(module, FlaubertModel):
+            if self.config.sinusoidal_embeddings:
+                init.copy_(
+                    module.position_embeddings.weight,
+                    create_sinusoidal_embeddings(
+                        self.config.max_position_embeddings,
+                        self.config.emb_dim,
+                        out=torch.empty_like(module.position_embeddings.weight),
+                    ),
+                )
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
 @auto_docstring
@@ -760,15 +759,15 @@ class FlaubertModel(FlaubertPreTrainedModel):
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
 
-        # Initialize weights and apply final processing
-        self.post_init()
-
         self.layerdrop = getattr(config, "layerdrop", 0.0)
         self.pre_norm = getattr(config, "pre_norm", False)
         self.register_buffer(
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )
 
+        # Initialize weights and apply final processing
+        self.post_init()
+
     # Copied from transformers.models.xlm.modeling_xlm.XLMModel.get_input_embeddings
     def get_input_embeddings(self):
         return self.embeddings
@@ -792,6 +791,7 @@ class FlaubertModel(FlaubertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1002,6 +1002,7 @@ class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         r"""
         langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1090,6 +1091,7 @@ class FlaubertForSequenceClassification(FlaubertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1195,6 +1197,7 @@ class FlaubertForTokenClassification(FlaubertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1286,6 +1289,7 @@ class FlaubertForQuestionAnsweringSimple(FlaubertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
    ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1423,6 +1427,7 @@ class FlaubertForQuestionAnswering(FlaubertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, FlaubertForQuestionAnsweringOutput]:
         r"""
         langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1538,6 +1543,7 @@ class FlaubertForMultipleChoice(FlaubertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MultipleChoiceModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
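The Flaubert change moves sinusoidal position-embedding initialization into `_init_weights`, filling the weight in place via `create_sinusoidal_embeddings(..., out=...)`. As a rough illustration of the table that helper fills in (a simplified re-implementation for illustration, not the library function; even columns get sines, odd columns cosines):

```python
import torch

def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    # assumes an even dim, as embedding dims are in practice
    position = torch.arange(n_pos, dtype=torch.float32).unsqueeze(1)
    div_term = torch.pow(10000.0, torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    table = torch.zeros(n_pos, dim)
    table[:, 0::2] = torch.sin(position / div_term)
    table[:, 1::2] = torch.cos(position / div_term)
    return table

print(sinusoidal_table(512, 64).shape)  # torch.Size([512, 64])
```

Evaluated at `max_position_embeddings × emb_dim`, the table has the same shape as `module.position_embeddings.weight`, which is why the init can simply `copy_` it in.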
@@ -306,7 +306,6 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return processed_images
 
@@ -397,7 +396,6 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
                 mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
             )
             masks = [mask_generator() for _ in range(len(images))]
-            masks = torch.stack(masks, dim=0) if return_tensors else masks
             data["bool_masked_pos"] = masks
 
         return BatchFeature(data=data, tensor_type=return_tensors)
@@ -677,6 +677,9 @@ class FlavaPreTrainedModel(PreTrainedModel):
             init.zeros_(module.position_embeddings)
             if module.mask_token is not None:
                 init.zeros_(module.mask_token)
+        elif isinstance(module, FlavaTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
         elif isinstance(module, FlavaMultimodalModel):
             if module.use_cls_token:
                 init.zeros_(module.cls_token)
@@ -725,6 +728,7 @@ class FlavaImageModel(FlavaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
@@ -804,6 +808,7 @@ class FlavaTextModel(FlavaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`):
@@ -896,6 +901,7 @@ class FlavaMultimodalModel(FlavaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`):
@@ -1103,7 +1109,8 @@ class FlavaModel(FlavaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: bool = True,
         return_dict: Optional[bool] = None,
-
+        **kwargs,
+    ) -> Union[tuple, FlavaModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, image_num_patches + text_seq_len)`):
             Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
@@ -1380,7 +1387,7 @@ class FlavaImageCodebook(FlavaPreTrainedModel):
         z_logits = self.blocks(pixel_values)
         return nn.Softmax(dim=1)(z_logits)
 
-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+    def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> torch.Tensor:
         f"""
         Args:
             pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
@@ -1575,6 +1582,7 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
         output_hidden_states: bool = True,
         return_dict: Optional[bool] = None,
         return_loss: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], FlavaForPreTrainingOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_len)`):
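The new `FlavaTextEmbeddings` branch in `_init_weights` restores the text embeddings' non-persistent buffers so they match what the module's constructor registers. A tiny sketch of the values it writes (length 8 is an illustrative assumption):

```python
import torch

max_len = 8
position_ids = torch.arange(max_len).expand((1, -1))  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
token_type_ids = torch.zeros_like(position_ids)        # all zeros, same shape
```

That is, positions count up from zero and token types default to segment 0.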
@@ -30,15 +30,15 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_flex_olmo import FlexOlmoConfig
 
 
@@ -80,7 +80,7 @@ class FlexOlmoRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -119,7 +119,7 @@ class FlexOlmoRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -216,6 +216,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class FlexOlmoAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -241,7 +242,6 @@ class FlexOlmoAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = FlexOlmoRMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
         self.k_norm = FlexOlmoRMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
 
@@ -252,7 +252,6 @@ class FlexOlmoAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_shape = hidden_states.shape[:-1]
@@ -294,6 +293,7 @@ class FlexOlmoAttention(nn.Module):
         return attn_output, attn_weights
 
 
+@use_experts_implementation
 class FlexOlmoExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
@@ -422,7 +422,9 @@ class FlexOlmoPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
    _can_record_outputs = {
         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
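`maybe_autocast(..., enabled=False)` keeps the rotary-frequency product in float32 regardless of any surrounding autocast context. The core computation it wraps, sketched with tiny illustrative sizes (dim=8 and four positions are assumptions for the example, not FlexOlmo's real dimensions):

```python
import torch

dim, seq_len = 8, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
position_ids = torch.arange(seq_len)[None, :].float()                # (1, seq_len)

# Outer product of inverse frequencies and positions, done in float32.
freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)  # (1, seq_len, dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                                        # (1, seq_len, dim)
cos, sin = emb.cos(), emb.sin()
print(cos.shape, sin.shape)  # torch.Size([1, 4, 8]) torch.Size([1, 4, 8])
```

Registering `original_inv_freq` as a non-persistent buffer, rather than a plain attribute, also means it follows the module across `.to(device)` moves just like `inv_freq`.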