transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,332 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glmasr/modular_glmasr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glmasr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Optional, Union
+
+import numpy as np
+
+from ...audio_utils import AudioInput, make_list_of_audio
+from ...feature_extraction_utils import BatchFeature
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import TextInput
+from ...utils import is_torch_available, logging
+
+
+if is_torch_available():
+    import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+class GlmAsrProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": True,
+        },
+        "audio_kwargs": {
+            "sampling_rate": 16000,
+            "chunk_length": 30.0,
+            "return_attention_mask": True,
+            "padding": "max_length",
+        },
+        "common_kwargs": {
+            "return_tensors": "pt",
+            "padding_side": "left",
+        },
+    }
+
+
+class GlmAsrProcessor(ProcessorMixin):
+    r"""
+    Constructs an GlmAsr processor which wraps an GlmAsr feature extractor and an GlmAsr
+    tokenizer into a single processor.
+
+    [`GlmAsrProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and
+    [`Qwen2TokenizerFast`]. See the [`~GlmAsrProcessor.__call__`] for more information.
+
+    Args:
+        feature_extractor ([`WhisperFeatureExtractor`]):
+            The feature extractor is a required input.
+        tokenizer ([`Qwen2TokenizerFast`]):
+            The tokenizer is a required input.
+        chat_template (`Optional[str]`, *optional*):
+            The Jinja template to use for formatting the conversation. If not provided, the tokenizer's default chat
+            template will be used.
+        audio_token (`Optional[str]`, *optional*, defaults to `"<|pad|>"`):
+            Special token used to represent audio inputs in the chat template.
+        default_transcription_prompt (`str`, *optional*, defaults to `"Please transcribe this audio into text"`):
+            Default prompt to use for transcription tasks when applying transcription requests.
+        max_audio_len (`int`, *optional*, defaults to 655):
+            Maximum length of audio sequences in seconds. Audio longer than this will be truncated.
+            655 gives approximately 8192 tokens, corresponding to the maximum sequence length of the text model.
+    """
+
+    def __init__(
+        self,
+        feature_extractor,
+        tokenizer,
+        chat_template=None,
+        audio_token="<|pad|>",
+        default_transcription_prompt="Please transcribe this audio into text",
+        max_audio_len=655,
+    ):
+        self.audio_token = audio_token
+        self.audio_token_id = tokenizer.convert_tokens_to_ids(audio_token)
+        self.default_transcription_prompt = default_transcription_prompt
+        self.max_audio_len = max_audio_len
+        super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
+
+    def _get_audio_token_length(self, audio_lengths: "torch.Tensor") -> "torch.Tensor":
+        merge_factor = 4
+        for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
+            audio_lengths = (audio_lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
+
+        num_tokens = (audio_lengths - merge_factor) // merge_factor + 1
+        return num_tokens
+
+    def __call__(
+        self,
+        text: Union[TextInput, list[TextInput]],
+        audio: Optional[AudioInput] = None,
+        output_labels: Optional[bool] = False,
+        **kwargs: Unpack[GlmAsrProcessorKwargs],
+    ) -> BatchFeature:
+        r"""
+        Main method to prepare one or several text sequence(s) and audio waveform(s) for the model. This
+        method expands `<sound>` placeholders in the text based on the post-pool frame counts of the
+        audio windows, then tokenizes the provided strings as-is, and extracts log-mel features
+        with [`WhisperFeatureExtractor`]. If `audio` is `None`, no audio processing is performed and
+        the text is tokenized as-is (LM-only behavior).
+
+        Args:
+            text (`str` or `list[str]`):
+                Input sequence or batch of sequences.
+            audio (`np.ndarray` or `list[np.ndarray]`):
+                Input audio or batch of audios as NumPy arrays. If provided, there must be as many `text` inputs as
+                `audio` inputs.
+            output_labels (bool, *optional*, default=False):
+                Whether to return labels for training.
+
+        Returns:
+            [`BatchFeature`]: A dictionary with tokenized text (`input_ids`, `attention_mask`) and
+                audio features (`input_features`, `input_features_mask`).
+        """
+
+        # Merge defaults with user kwargs
+        call_kwargs = self._merge_kwargs(
+            GlmAsrProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        text_kwargs = call_kwargs["text_kwargs"]
+        audio_kwargs = call_kwargs["audio_kwargs"]
+        return_tensors = text_kwargs.get("return_tensors")
+        if return_tensors != "pt":
+            raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
+
+        if isinstance(text, str):
+            text = [text]
+        elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+            raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+        audio_inputs = {}
+        if audio is not None:
+            audio = make_list_of_audio(audio)
+            if len(text) != len(audio):
+                raise ValueError(f"Got {len(text)} text but {len(audio)} audios; they must match 1:1.")
+
+            # Determine number of chunks per sample, and flatten
+            window_size = int(audio_kwargs["sampling_rate"] * audio_kwargs["chunk_length"])
+            max_windows = int(self.max_audio_len // audio_kwargs["chunk_length"])
+
+            per_sample_windows: list[int] = []
+            flat_chunks: list[np.ndarray] = []
+
+            for audio_el in audio:
+                n_samples = int(audio_el.shape[0])
+                n_win = max(1, (n_samples + window_size - 1) // window_size)
+                if n_win > max_windows:
+                    logger.warning(
+                        f"Audio duration ({n_samples / audio_kwargs['sampling_rate']:.1f}s) exceeds {self.max_audio_len}s; truncating to first {self.max_audio_len}s."
+                    )
+                    n_win = max_windows
+                per_sample_windows.append(n_win)
+
+                time_cap = min(n_samples, n_win * window_size)
+                for i in range(n_win):
+                    start = i * window_size
+                    end = min((i + 1) * window_size, time_cap)
+                    flat_chunks.append(audio_el[start:end])
+
+            # Feature extraction
+            audio_inputs = self.feature_extractor(flat_chunks, **audio_kwargs)
+            padding_mask = audio_inputs.pop("attention_mask")
+            audio_inputs["input_features_mask"] = padding_mask
+
+            # Compute sequence lengths token counting
+            audio_lengths = torch.stack([s.sum() for s in torch.split(padding_mask.sum(-1), per_sample_windows)])
+            audio_tokens_lengths = self._get_audio_token_length(audio_lengths)
+
+            # expand audio tokens in text
+            for i, audio_length in enumerate(audio_tokens_lengths):
+                expanded = re.sub(re.escape(self.audio_token), self.audio_token * audio_length, text[i])
+                text[i] = expanded
+
+        # Tokenize
+        text_inputs = self.tokenizer(text, **text_kwargs)
+
+        data = {**text_inputs, **audio_inputs}
+        if output_labels:
+            labels = data["input_ids"].clone()
+            labels[labels == self.audio_token_id] = -100
+            labels[labels == self.tokenizer.pad_token_id] = -100
+            data["labels"] = labels
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    @property
+    def model_input_names(self) -> list[str]:
+        tok_names = self.tokenizer.model_input_names
+        fea_names = self.feature_extractor.model_input_names
+        return list(dict.fromkeys(tok_names + fea_names + ["input_features_mask"]))
+
+    def apply_transcription_request(
+        self,
+        audio: Union[str, list[str], AudioInput],
+        prompt: Optional[Union[str, list[str]]] = None,
+        **kwargs: Unpack[GlmAsrProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Prepare inputs for automatic speech recognition without manually writing the default transcription prompt.
+
+        Args:
+            audio (`str`, `list[str]`, `np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
+                Audio to transcribe. Strings are interpreted as local paths or URLs and will be loaded automatically by
+                the chat template loader; NumPy arrays and PyTorch tensors are forwarded directly.
+            prompt (`str` or `list[str]`, *optional*):
+                Custom prompt(s) to include in the user turn. A list must be the same length as the batch. When `None`,
+                each sample uses `"Transcribe the input speech."`.
+            **kwargs:
+                Additional keyword arguments forwarded to [`~AudioFlamingo3Processor.apply_chat_template`] (for example
+                `text_kwargs`, `audio_kwargs`, ...).
+
+        Returns:
+            [`BatchFeature`]: Processor outputs ready to be passed to [`AudioFlamingo3ForConditionalGeneration.generate`].
+
+        """
+
+        if isinstance(audio, str):
+            audio_items: list[Union[str, np.ndarray]] = [audio]
+        elif isinstance(audio, (list, tuple)) and audio and all(isinstance(el, str) for el in audio):
+            audio_items = list(audio)
+        else:
+            audio_items = list(make_list_of_audio(audio))
+            if is_torch_available():
+                audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items]
+
+        batch_size = len(audio_items)
+        if batch_size == 0:
+            raise ValueError("`audio` must contain at least one sample.")
+
+        if prompt is None:
+            prompts = [self.default_transcription_prompt] * batch_size
+        elif isinstance(prompt, str):
+            prompts = [prompt] * batch_size
+        elif isinstance(prompt, (list, tuple)):
+            if len(prompt) != batch_size:
+                raise ValueError(
+                    f"Received {len(prompt)} prompt(s) for {batch_size} audio sample(s); counts must match."
+                )
+            prompts = []
+            for item in prompt:
+                if item is None:
+                    prompts.append(self.default_transcription_prompt)
+                elif isinstance(item, str):
+                    prompts.append(item)
+                else:
+                    raise TypeError("Each prompt must be a string or `None`.")
+        else:
+            raise TypeError("`prompt` must be a string, a sequence of strings, or `None`.")
+
+        conversations = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "audio", "path": audio_item}
+                        if isinstance(audio_item, str)
+                        else {"type": "audio", "audio": audio_item},
+                        {"type": "text", "text": prompt_text},
+                    ],
+                }
+            ]
+            for prompt_text, audio_item in zip(prompts, audio_items)
+        ]
+
+        return self.apply_chat_template(
+            conversations,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            **kwargs,
+        )
+
+    def batch_decode(self, *args, strip_prefix=False, **kwargs):
+        """
+        Forward arguments to [`~PreTrainedTokenizer.batch_decode`] and optionally remove the assistant framing the model
+        was trained to produce.
+
+        AF3 transcription requests respond with sentences such as `"The spoken content of the audio is \"...\"."`.
+        Setting `strip_prefix=True` trims the fixed prefix for just the transcription text.
+        """
+        decoded = self.tokenizer.batch_decode(*args, **kwargs)
+        if strip_prefix:
+            decoded = [self._strip_assistant_prefix_and_quotes(text) for text in decoded]
+        return decoded
+
+    def _strip_assistant_prefix_and_quotes(self, text: str) -> str:
+        """
+        Remove the assistant prefix and surrounding quotes from a decoded transcription string.
+        """
+
+        stripped = text.strip()
+
+        for prefix in (
+            "The spoken content of the audio is",
+            "The transcription of the audio is",
+        ):
+            if stripped.startswith(prefix):
+                stripped = stripped[len(prefix) :].strip()
+                break
+
+        if stripped.endswith("."):
+            stripped = stripped[:-1].strip()
+
+        if len(stripped) >= 2 and stripped[0] == stripped[-1] and stripped[0] in {"'", '"'}:
+            stripped = stripped[1:-1].strip()
+
+        return stripped
+
+
+__all__ = ["GlmAsrProcessor"]
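Note: a minimal usage sketch of the new GlmAsrProcessor added above. The repo id below is a placeholder and it is assumed that GLM-ASR is registered with AutoProcessor (processing_auto.py is updated in this release); only the `__call__` API shown in the hunk is relied on. By the conv/merge arithmetic in `_get_audio_token_length`, a full 30 s window (3000 mel frames) maps to roughly 375 audio tokens.

    import numpy as np
    from transformers import AutoProcessor

    # Placeholder repo id, for illustration only.
    processor = AutoProcessor.from_pretrained("org/glm-asr-checkpoint")

    # 5 seconds of silence at the 16 kHz sampling rate expected by the Whisper feature extractor.
    waveform = np.zeros(16000 * 5, dtype=np.float32)
    text = "<|pad|>Please transcribe this audio into text"

    # __call__ chunks the audio into 30 s windows, extracts log-mel features, and expands
    # the single <|pad|> placeholder to one token per post-pool audio frame.
    inputs = processor(text=text, audio=waveform)
    print(sorted(inputs.keys()))
    # ['attention_mask', 'input_features', 'input_features_mask', 'input_ids']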
@@ -107,7 +107,6 @@ class GLPNImageProcessorFast(BaseImageProcessorFast):
             processed_groups[shape] = stacked_images
 
         processed_images = reorder_images(processed_groups, grouped_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
     def post_process_depth_estimation(self, outputs, target_sizes=None):
@@ -189,7 +189,6 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(
             data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors
@@ -433,6 +433,7 @@ class GotOcr2VisionEncoder(GotOcr2PreTrainedModel):
         self.neck = GotOcr2VisionNeck(config)
 
         self.gradient_checkpointing = False
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.patch_embed
@@ -796,6 +797,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -807,12 +809,15 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if
-        #
-        #
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration if available
+            # In subsquent iterations, they are already merged with text and cached
+            # NOTE: first iteration doesn't have to be prefill, it can be the first
+            # iteration with a question and cached system prompt (continue generate from cache)
             model_inputs["pixel_values"] = pixel_values
 
         return model_inputs
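Note: the pixel-value gate added above reduces to a simple predicate; the sketch below is illustrative only (the function name is hypothetical, not the actual method) and just restates when `pixel_values` is forwarded.

    # Illustrative sketch of the gating added in prepare_inputs_for_generation.
    def keep_pixel_values(is_first_iteration: bool, use_cache: bool = True) -> bool:
        # Pixel values are forwarded only on the first iteration (or when no cache is used);
        # afterwards the vision features are already merged into the cached sequence.
        return is_first_iteration or not use_cache

    assert keep_pixel_values(True) is True                     # first call (prefill, or continuing from a cached system prompt)
    assert keep_pixel_values(False) is False                   # cached decoding steps skip pixel values
    assert keep_pixel_values(False, use_cache=False) is True   # no cache: always forward them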
@@ -103,7 +103,6 @@ class GPT2Attention(nn.Module):
             ),
             persistent=False,
         )
-        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
 
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -476,12 +475,8 @@ class GPT2PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_attention_backend = True
-
     _can_compile_fullgraph = True
 
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
     @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -497,6 +492,14 @@ class GPT2PreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+        elif isinstance(module, GPT2Attention):
+            max_positions = module.config.max_position_embeddings
+            init.copy_(
+                module.bias,
+                torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+                    1, 1, max_positions, max_positions
+                ),
+            )
 
         # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
         # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
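Note: the buffer filled here is the standard lower-triangular causal mask; for a toy `max_positions` (value assumed for illustration) it looks like this:

    import torch

    max_positions = 4  # toy value; GPT-2 uses config.max_position_embeddings
    mask = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
    print(mask.int())
    # tensor([[1, 0, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 0],
    #         [1, 1, 1, 1]])
    # .view(1, 1, max_positions, max_positions) then adds batch/head broadcast dimensions.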
@@ -26,7 +26,6 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask
-from ...modeling_flash_attention_utils import is_flash_attn_available
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -43,10 +42,6 @@ from ...utils import (
 from .configuration_gpt_bigcode import GPTBigCodeConfig
 
 
-if is_flash_attn_available():
-    pass
-
-
 logger = logging.get_logger(__name__)
 
 
@@ -360,9 +355,6 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
 
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
     @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -377,6 +369,9 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
             init.normal_(
                 module.c_proj.weight, mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
             )
+        elif isinstance(module, GPTBigCodeModel):
+            max_positions = module.config.max_position_embeddings
+            init.copy_(module.bias, torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)))
 
 
 @auto_docstring
@@ -20,6 +20,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
@@ -70,11 +71,11 @@ class GPTNeoSelfAttention(nn.Module):
         # local causal self attention is a sliding window where each token can only attend to the previous
         # window_size tokens. This is implemented by updating the causal mask such that for each token
         # all other tokens are masked except the previous window_size tokens.
+        self.attention_type = attention_type
         if attention_type == "local":
             bias = torch.bitwise_xor(bias, torch.tril(bias, -config.window_size))
 
         self.register_buffer("bias", bias, persistent=False)
-        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
 
         self.attn_dropout = nn.Dropout(float(config.attention_dropout))
         self.resid_dropout = nn.Dropout(float(config.resid_dropout))
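Note: the `bitwise_xor` trick described in the comment turns the full causal mask into a sliding-window mask; a toy example (sizes assumed for illustration) makes the effect visible:

    import torch

    max_positions, window_size = 5, 2  # toy values
    bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
    local = torch.bitwise_xor(bias, torch.tril(bias, -window_size))
    print(local.int())
    # tensor([[1, 0, 0, 0, 0],
    #         [1, 1, 0, 0, 0],
    #         [0, 1, 1, 0, 0],
    #         [0, 0, 1, 1, 0],
    #         [0, 0, 0, 1, 1]])
    # each query row keeps at most the last window_size key positions.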
@@ -237,8 +238,8 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
                     else torch.get_autocast_gpu_dtype()
                 )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "
-                target_dtype = self.config.
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
 
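Note: the target-dtype fallback shown in this hunk can be summarized as a small helper; the sketch below is illustrative only (the function name is hypothetical), it just restates the order of precedence visible above.

    # Illustrative sketch of the dtype fallback: autocast dtype -> quantized-model dtype -> weight dtype.
    def pick_target_dtype(config, weight, autocast_enabled: bool, autocast_dtype=None):
        if autocast_enabled:
            return autocast_dtype                   # honour torch autocast when it is active
        if hasattr(config, "quantization_config"):
            return config.dtype                     # quantized model: use the configured dtype
        return weight.dtype                         # otherwise fall back to the projection weight dtype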
@@ -382,6 +383,17 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _can_compile_fullgraph = False  # TODO: needs a hybrid cache
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, GPTNeoSelfAttention):
+            max_positions = module.config.max_position_embeddings
+            bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view(
+                1, 1, max_positions, max_positions
+            )
+            if module.attention_type == "local":
+                bias = torch.bitwise_xor(bias, torch.tril(bias, -module.config.window_size))
+            init.copy_(module.bias, bias)
+
 
 @auto_docstring
 class GPTNeoModel(GPTNeoPreTrainedModel):
@@ -66,7 +66,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
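Note: registering `original_inv_freq` as a non-persistent buffer (rather than a plain attribute) means it follows the module across devices and dtypes without being written to the checkpoint; a minimal sketch with an assumed module name:

    import torch
    from torch import nn

    class RotarySketch(nn.Module):  # hypothetical module, for illustration only
        def __init__(self, inv_freq: torch.Tensor):
            super().__init__()
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            # a non-persistent buffer moves with .to(...) but is excluded from state_dict()
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    m = RotarySketch(torch.arange(4, dtype=torch.float32))
    print(list(m.state_dict().keys()))  # [] -- neither buffer is serialized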
@@ -78,7 +78,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -117,5 +117,22 @@ class GptOssConfig(PreTrainedConfig):
             **kwargs,
         )
 
+    def __setattr__(self, key, value):
+        """
+        Overwritten to allow checking for the proper attention implementation to be used.
+
+        Due to `set_attn_implementation` which internally assigns `_attn_implementation_internal = "..."`, simply overwriting
+        the specific attention setter is not enough. Using a property/setter for `_attn_implementation_internal` would result in
+        a recursive dependency (as `_attn_implementation` acts as a wrapper around `_attn_implementation_internal`) - hence, this
+        workaround.
+        """
+        if key in ("_attn_implementation", "_attn_implementation_internal"):
+            if value and "flash" in value and value.removeprefix("paged|") != "kernels-community/vllm-flash-attn3":
+                raise ValueError(
+                    f"GPT-OSS model does not support the specified flash attention implementation: {value}. "
+                    "Only `kernels-community/vllm-flash-attn3` is supported."
+                )
+        super().__setattr__(key, value)
+
 
 __all__ = ["GptOssConfig"]
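Note: assuming the config behaves as the hunk above describes (and that `GptOssConfig` is importable from the top-level package, which is an assumption here), selecting an unsupported flash attention backend fails fast at assignment time; usage sketch, not taken from the package docs:

    from transformers import GptOssConfig  # import path assumed

    config = GptOssConfig()
    config._attn_implementation = "kernels-community/vllm-flash-attn3"  # accepted

    try:
        config._attn_implementation = "flash_attention_2"  # any other flash backend is rejected
    except ValueError as err:
        print(err)  # GPT-OSS model does not support the specified flash attention implementation: ...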
@@ -28,8 +28,7 @@ from torch.nn import functional as F
 from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernelized_func
-from ...integrations.hub_kernels import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import (
     GenericForSequenceClassification,
@@ -89,8 +88,8 @@ class GptOssExperts(nn.Module):
 
         Args:
             hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
-            selected_experts (torch.Tensor): (batch_size *
-            routing_weights (torch.Tensor): (batch_size *
+            selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
+            routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
         Returns:
             torch.Tensor
         """
@@ -160,8 +159,8 @@ class GptOssTopKRouter(nn.Module):
 
     def forward(self, hidden_states):
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (
+        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (num_tokens, num_experts)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (num_tokens, top_k)
         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
         router_scores = router_top_value
         return router_logits, router_scores, router_indices
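Note: the shape comments restored above can be checked with dummy tensors; the snippet below is a standalone sketch of the same top-k routing step with toy sizes (all values assumed for illustration):

    import torch
    import torch.nn.functional as F

    num_tokens, num_experts, top_k, hidden_dim = 6, 32, 4, 16  # toy sizes
    hidden_states = torch.randn(num_tokens, hidden_dim)
    weight, bias = torch.randn(num_experts, hidden_dim), torch.randn(num_experts)

    router_logits = F.linear(hidden_states, weight, bias)                         # (num_tokens, num_experts)
    router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1)   # (num_tokens, top_k)
    router_scores = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
    print(router_logits.shape, router_top_value.shape, router_indices.shape)
    # torch.Size([6, 32]) torch.Size([6, 4]) torch.Size([6, 4])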
@@ -197,7 +196,7 @@ class GptOssRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -445,8 +444,6 @@ class GptOssPreTrainedModel(PreTrainedModel):
         "attentions": GptOssAttention,
     }
     _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
-    _supports_flash_attention = False
-    _supports_flex_attention = False
 
     @torch.no_grad()
     def _init_weights(self, module):
@@ -21,7 +21,7 @@ from torch.nn import functional as F
 
 from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache
-from ...integrations
+from ...integrations import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_outputs import (
     MoeModelOutputWithPast,
@@ -86,8 +86,8 @@ class GptOssExperts(nn.Module):
 
         Args:
             hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
-            selected_experts (torch.Tensor): (batch_size *
-            routing_weights (torch.Tensor): (batch_size *
+            selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
+            routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
         Returns:
             torch.Tensor
         """
@@ -157,8 +157,8 @@ class GptOssTopKRouter(nn.Module):
 
     def forward(self, hidden_states):
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (
+        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (num_tokens, num_experts)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (num_tokens, top_k)
        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
         router_scores = router_top_value
         return router_logits, router_scores, router_indices
@@ -354,8 +354,6 @@ class GptOssDecoderLayer(LlamaDecoderLayer):
 class GptOssPreTrainedModel(LlamaPreTrainedModel):
     _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
     _supports_sdpa = False
-    _supports_flash_attention = False
-    _supports_flex_attention = False
     _can_record_outputs = {
         "router_logits": OutputRecorder(GptOssTopKRouter, index=0),
         "hidden_states": GptOssDecoderLayer,