transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,594 @@
|
|
|
1
|
+
# coding=utf-8
|
|
2
|
+
# Copyright 2025 Baidu and HuggingFace Inc. team. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
import os.path
|
|
16
|
+
from functools import partial
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from shutil import SameFileError, copyfile
|
|
19
|
+
from typing import Any, Optional, Union
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import torch
|
|
23
|
+
from huggingface_hub import is_offline_mode
|
|
24
|
+
from huggingface_hub.dataclasses import validate_typed_dict
|
|
25
|
+
from PIL import ImageDraw, ImageFont
|
|
26
|
+
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
|
|
27
|
+
|
|
28
|
+
from ...image_processing_utils import BatchFeature
|
|
29
|
+
from ...image_utils import (
|
|
30
|
+
OPENAI_CLIP_MEAN,
|
|
31
|
+
OPENAI_CLIP_STD,
|
|
32
|
+
ChannelDimension,
|
|
33
|
+
PILImageResampling,
|
|
34
|
+
SizeDict,
|
|
35
|
+
get_image_size,
|
|
36
|
+
validate_kwargs,
|
|
37
|
+
)
|
|
38
|
+
from ...processing_utils import Unpack, VideosKwargs
|
|
39
|
+
from ...utils import (
|
|
40
|
+
IMAGE_PROCESSOR_NAME,
|
|
41
|
+
PROCESSOR_NAME,
|
|
42
|
+
VIDEO_PROCESSOR_NAME,
|
|
43
|
+
TensorType,
|
|
44
|
+
add_start_docstrings,
|
|
45
|
+
logging,
|
|
46
|
+
safe_load_json_file,
|
|
47
|
+
)
|
|
48
|
+
from ...utils.hub import cached_file
|
|
49
|
+
from ...utils.import_utils import is_tracing, requires
|
|
50
|
+
from ...video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor
|
|
51
|
+
from ...video_utils import (
|
|
52
|
+
VideoInput,
|
|
53
|
+
VideoMetadata,
|
|
54
|
+
group_videos_by_shape,
|
|
55
|
+
infer_channel_dimension_format,
|
|
56
|
+
reorder_videos,
|
|
57
|
+
)
|
|
58
|
+
from .image_processing_ernie4_5_vl_moe import smart_resize
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# Module-level logger; used below for offline-mode, config-loading and missing-fps messages.
logger = logging.get_logger(__name__)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class Ernie4_5_VL_MoeVideoProcessorInitKwargs(VideosKwargs, total=False):
    """
    Ernie 4.5 VL Moe specific video processor kwargs, on top of the generic `VideosKwargs`.

    The field names are also used by `Ernie4_5_VL_MoeVideoProcessor.preprocess` (through
    `valid_kwargs.__annotations__`) to validate user kwargs and fill in defaults from the instance.
    """

    # Spatial patch size of the vision encoder.
    patch_size: int
    # Temporal patch size; the processor's `__init__` only accepts 2.
    temporal_patch_size: int
    # Spatial merge factor; `patch_size * merge_size` is the resize factor used in `_preprocess`.
    merge_size: int
    # Lower / upper bounds on the number of sampled frames (see `sample_frames`).
    min_frames: int
    max_frames: int
    # Whether to draw each frame's timestamp onto the frame itself.
    draw_on_frames: bool
    # Font file name (or resolved path) used when drawing timestamps.
    font: str
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@add_start_docstrings(
|
|
75
|
+
"Constructs a fast Ernie 4.5 VL image processor that dynamically resizes videos based on the original videos.",
|
|
76
|
+
BASE_VIDEO_PROCESSOR_DOCSTRING,
|
|
77
|
+
"""
|
|
78
|
+
patch_size (`int`, *optional*, defaults to 14):
|
|
79
|
+
The spacial patch size of the vision encoder.
|
|
80
|
+
temporal_patch_size (`int`, *optional*, defaults to 2):
|
|
81
|
+
The temporal patch size of the vision encoder.
|
|
82
|
+
merge_size (`int`, *optional*, defaults to 2):
|
|
83
|
+
The merge size of the vision encoder to llm encoder.
|
|
84
|
+
min_frames (`int`, *optional*, defaults to 16):
|
|
85
|
+
The minimum number of frames that can be sampled.
|
|
86
|
+
max_frames (`int`, *optional*, defaults to 180):
|
|
87
|
+
The maximum number of frames that can be sampled.
|
|
88
|
+
draw_on_frames (`bool`, *optional*, defaults to `True`):
|
|
89
|
+
Whether to draw timestamps on each frame or not.
|
|
90
|
+
This does not work with `torch.compile` but resembles
|
|
91
|
+
the performance of the original model.
|
|
92
|
+
font (`str`, *optional*, defaults to "Roboto-Regular.ttf"):
|
|
93
|
+
The associated font name for drawing on frames.
|
|
94
|
+
Defaults to "Roboto-Regular.ttf" and is expected to be
|
|
95
|
+
saved along the processor as separate file.
|
|
96
|
+
""",
|
|
97
|
+
)
|
|
98
|
+
@requires(backends=("torchvision",))
|
|
99
|
+
class Ernie4_5_VL_MoeVideoProcessor(BaseVideoProcessor):
|
|
100
|
+
resample = PILImageResampling.BICUBIC
# Pixel budgets passed to `smart_resize` as `min_pixels` / `max_pixels` in `_preprocess`.
# Written as multiples of 28 * 28; presumably 28 = patch_size (14) * merge_size (2) -- TODO confirm.
size = {"shortest_edge": 299 * 28 * 28, "longest_edge": 1196 * 28 * 28}
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
patch_size = 14
# `__init__` rejects any other value.
temporal_patch_size = 2
merge_size = 2
# Bounds used by `sample_frames`.
min_frames = 16
max_frames = 180
do_sample_frames = True
# Timestamps are drawn on frames by default, which requires `font` to resolve to a font file.
draw_on_frames = True
font = "Roboto-Regular.ttf"
# Used by `preprocess` to validate and default kwargs.
valid_kwargs = Ernie4_5_VL_MoeVideoProcessorInitKwargs
# Keys of the `BatchFeature` returned by `_preprocess`.
model_input_names = ["pixel_values_videos", "video_grid_thw"]
|
|
118
|
+
|
|
119
|
+
def __init__(self, **kwargs: Unpack[Ernie4_5_VL_MoeVideoProcessorInitKwargs]):
|
|
120
|
+
temporal_patch_size = kwargs.get("temporal_patch_size", 2)
|
|
121
|
+
if temporal_patch_size is None or temporal_patch_size != 2:
|
|
122
|
+
raise ValueError("`Ernie 4.5 VL` only supports a temporal patch size of 2")
|
|
123
|
+
|
|
124
|
+
size = kwargs.pop("size", None)
|
|
125
|
+
size = self.size if size is None else size
|
|
126
|
+
if "shortest_edge" not in size or "longest_edge" not in size:
|
|
127
|
+
raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
|
|
128
|
+
|
|
129
|
+
super().__init__(size=size, **kwargs)
|
|
130
|
+
|
|
131
|
+
@classmethod
def get_video_processor_dict(
    cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Resolve the video processor configuration dict and, for Ernie 4.5 VL Moe, its font file.

    The configuration is looked up in this order: the nested `video_processor` entry of the processor
    config, then the video processor config, then the image processor config. When `draw_on_frames` is
    enabled in the resulting dict, its `font` entry (a file name) is replaced by the resolved local path
    of that font, looked up in the same location as the configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            A Hub repo id, a local directory, or the path to a single json configuration file. In the
            last case the font is looked up in the directory containing that file.
        **kwargs:
            The loading options `cache_dir`, `force_download`, `proxies`, `token`, `local_files_only`,
            `revision`, `subfolder`, `_from_pipeline` and `_from_auto` are consumed here; every other
            kwarg is returned unchanged.

    Returns:
        `tuple[dict[str, Any], dict[str, Any]]`: the video processor dict and the remaining kwargs.

    Raises:
        OSError: if no configuration can be found or loaded, or if the font file is missing or cannot be
            opened by Pillow.
        AttributeError: if `draw_on_frames` is enabled but the configuration has no `font`.
    """
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    token = kwargs.pop("token", None)
    local_files_only = kwargs.pop("local_files_only", False)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", "")

    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)

    user_agent = {"file_type": "video processor", "from_auto_class": from_auto_class}
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Options shared by every `cached_file` call below; missing files resolve to `None` instead of raising.
    hub_kwargs = {
        "cache_dir": cache_dir,
        "force_download": force_download,
        "proxies": proxies,
        "local_files_only": local_files_only,
        "token": token,
        "user_agent": user_agent,
        "revision": revision,
        "subfolder": subfolder,
        "_raise_exceptions_for_missing_entries": False,
    }

    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    is_local = os.path.isdir(pretrained_model_name_or_path)
    # Defined up-front so the error messages below never reference an unbound name.
    video_processor_file = VIDEO_PROCESSOR_NAME
    # Location searched for the font: a repo id / directory is used as-is, while for a single json file
    # we search next to it (`cached_file` cannot use a file path as a repo).
    font_search_path = pretrained_model_name_or_path
    if os.path.isfile(pretrained_model_name_or_path):
        resolved_video_processor_file = pretrained_model_name_or_path
        resolved_processor_file = None
        is_local = True
        font_search_path = os.path.dirname(pretrained_model_name_or_path) or "."
    else:
        try:
            # Try to load with a new config name first and if not successful try with the old file name
            # NOTE: we save all processor configs as nested dict in PROCESSOR_NAME from v5, which is the standard
            resolved_processor_file = cached_file(
                pretrained_model_name_or_path, filename=PROCESSOR_NAME, **hub_kwargs
            )
            # First existing file wins: video processor config, then legacy image processor config.
            resolved_video_processor_file = None
            for filename in (video_processor_file, IMAGE_PROCESSOR_NAME):
                resolved_video_processor_file = cached_file(
                    pretrained_model_name_or_path, filename=filename, **hub_kwargs
                )
                if resolved_video_processor_file is not None:
                    break
        except OSError:
            # Raise any OS error raise by `cached_file`. It will have a helpful error message adapted to
            # the original exception.
            raise
        except Exception as err:
            # For any other exception, we throw a generic error.
            raise OSError(
                f"Can't load video processor for '{pretrained_model_name_or_path}'. If you were trying to load"
                " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
                f" directory containing a {video_processor_file} file"
            ) from err

    # Load video_processor dict. Priority goes as (nested config if found -> video processor config -> image processor config)
    # We are downloading both configs because almost all models have a `processor_config.json` but
    # not all of these are nested. We need to check if it was saved recently as nested or if it is legacy style
    video_processor_dict = None
    loaded_file = None  # the file the dict actually came from, for logging
    if resolved_processor_file is not None:
        processor_dict = safe_load_json_file(resolved_processor_file)
        if "video_processor" in processor_dict:
            video_processor_dict = processor_dict["video_processor"]
            loaded_file = resolved_processor_file

    if resolved_video_processor_file is not None and video_processor_dict is None:
        video_processor_dict = safe_load_json_file(resolved_video_processor_file)
        loaded_file = resolved_video_processor_file

    if video_processor_dict is None:
        raise OSError(
            f"Can't load video processor for '{pretrained_model_name_or_path}'. If you were trying to load"
            " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
            f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
            f" directory containing a {video_processor_file} file"
        )

    # Specific to Ernie 4.5 VL Moe, we load the font file along the json (if we draw on frames)
    draws_on_frames = video_processor_dict.get("draw_on_frames")
    font_name = video_processor_dict.get("font")
    if draws_on_frames:
        if font_name is None:
            raise AttributeError(
                "Expected a `font` to be saved when using `draw_on_frames` in Ernie 4.5 VL Moe; found nothing."
            )
        resolved_font = cached_file(font_search_path, filename=font_name, **hub_kwargs)
        # A missing file resolves to `None`; report it explicitly since Pillow would fail on `None`
        # with an unrelated error.
        if resolved_font is None:
            raise OSError(
                f"Could not find an associated font file for {font_name}. "
                "Make sure to save a font file along for Ernie 4.5 VL Moe."
            )
        video_processor_dict["font"] = resolved_font
        try:
            # Fail early if the file exists but is not a loadable font.
            ImageFont.truetype(resolved_font)
        except (TypeError, OSError) as err:
            raise OSError(
                f"Could not find an associated font file for {resolved_font}. "
                "Make sure to save a font file along for Ernie 4.5 VL Moe."
            ) from err

    if is_local:
        logger.info(f"loading configuration file {loaded_file}")
    else:
        logger.info(f"loading configuration file {Path(loaded_file).name} from cache at {loaded_file}")

    return video_processor_dict, kwargs
|
|
271
|
+
|
|
272
|
+
def to_dict(self) -> dict[str, Any]:
    """
    Serialize the video processor, keeping only the file name of the font.

    Overridden to strip the directory prefix of the font path, e.g. `tmp/folder/font.ttf` -> `font.ttf`,
    matching the file name under which `save_pretrained` copies the font.

    Returns:
        `dict[str, Any]`: the serialized attributes.

    Raises:
        ValueError: if `draw_on_frames` is enabled but `font` is unset or does not point to an existing file.
    """
    output = super().to_dict()

    font = output.get("font")
    # `os.path.isfile(None)` can raise a TypeError instead of returning False, so guard an unset font.
    if font is not None and os.path.isfile(font):
        output["font"] = Path(font).name
    elif output.get("draw_on_frames"):
        raise ValueError(
            f"The video processor dict contains an invalid path to its font: {font}. "
            "Please make sure to contain a valid path or disable `draw_on_frames`."
        )

    return output
|
|
285
|
+
|
|
286
|
+
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
    """
    Save the video processor and, if `font` points to an existing file, a copy of that font.

    The font is copied into `save_directory` under its file name; saving the configuration itself (and
    any push to the Hub) is delegated to the parent implementation, whose return value is returned.
    """
    os.makedirs(save_directory, exist_ok=True)

    font = self.font
    # `os.path.isfile(None)` can raise a TypeError, so an unset font is simply skipped.
    if font is not None and os.path.isfile(font):
        try:
            copyfile(font, Path(save_directory, Path(font).name))
        except SameFileError:  # already exists which we allow (copy if needed)
            pass

    # NOTE(review): the font is copied before `super().save_pretrained` runs. If the parent only uploads
    # files modified after it snapshots timestamps, the font may not be pushed with `push_to_hub=True`
    # -- confirm against `BaseVideoProcessor.save_pretrained`.
    return super().save_pretrained(save_directory, push_to_hub, **kwargs)
|
|
297
|
+
|
|
298
|
+
def _further_process_kwargs(
|
|
299
|
+
self,
|
|
300
|
+
size: Optional[SizeDict] = None,
|
|
301
|
+
**kwargs,
|
|
302
|
+
) -> dict:
|
|
303
|
+
"""
|
|
304
|
+
Update kwargs that need further processing before being validated
|
|
305
|
+
Can be overridden by subclasses to customize the processing of kwargs.
|
|
306
|
+
"""
|
|
307
|
+
if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
|
|
308
|
+
raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
|
|
309
|
+
|
|
310
|
+
return super()._further_process_kwargs(size=size, **kwargs)
|
|
311
|
+
|
|
312
|
+
def sample_frames(
|
|
313
|
+
self,
|
|
314
|
+
metadata: VideoMetadata,
|
|
315
|
+
min_frames: Optional[int] = None,
|
|
316
|
+
max_frames: Optional[int] = None,
|
|
317
|
+
num_frames: Optional[int] = None,
|
|
318
|
+
fps: Optional[Union[int, float]] = None,
|
|
319
|
+
**kwargs,
|
|
320
|
+
):
|
|
321
|
+
if fps is not None and num_frames is not None:
|
|
322
|
+
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
|
|
323
|
+
|
|
324
|
+
num_frames = num_frames if num_frames is not None else self.num_frames
|
|
325
|
+
min_frames = min_frames if min_frames is not None else self.min_frames
|
|
326
|
+
max_frames = max_frames if max_frames is not None else self.max_frames
|
|
327
|
+
total_num_frames = metadata.total_num_frames
|
|
328
|
+
|
|
329
|
+
if num_frames is not None:
|
|
330
|
+
if num_frames < min_frames or num_frames > max_frames:
|
|
331
|
+
raise ValueError(f"`num_frames` must be {min_frames} <= x <= {max_frames}. Got {num_frames} instead.")
|
|
332
|
+
else:
|
|
333
|
+
if fps is not None and (metadata is None or metadata.fps is None):
|
|
334
|
+
raise ValueError(
|
|
335
|
+
"Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
|
|
336
|
+
"Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
|
|
337
|
+
)
|
|
338
|
+
num_frames = total_num_frames / metadata.fps * fps if fps is not None else total_num_frames
|
|
339
|
+
num_frames = min(max(num_frames, min_frames), max_frames, total_num_frames)
|
|
340
|
+
|
|
341
|
+
if num_frames > total_num_frames:
|
|
342
|
+
raise ValueError(
|
|
343
|
+
f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
|
|
344
|
+
"Decrease `num_frames` or `fps` for sampling."
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
indices = torch.arange(0, total_num_frames, total_num_frames / num_frames).int()
|
|
348
|
+
|
|
349
|
+
return indices
|
|
350
|
+
|
|
351
|
+
def _convert_timestamp(self, time_stamp_in_seconds):
|
|
352
|
+
"""Convert to `time: hr:min:sec` format"""
|
|
353
|
+
hours = time_stamp_in_seconds // 3600
|
|
354
|
+
time_stamp_in_seconds = time_stamp_in_seconds % 3600
|
|
355
|
+
mins = time_stamp_in_seconds // 60
|
|
356
|
+
time_stamp_in_seconds = time_stamp_in_seconds % 60
|
|
357
|
+
return f"time: {int(hours):02d}:{int(mins):02d}:{time_stamp_in_seconds:05.02f}"
|
|
358
|
+
|
|
359
|
+
def _render_image_with_timestamp(self, image: torch.Tensor, timestamp: str, size_factor: float = 0.1):
    """
    Draw `timestamp` in black with a white outline at the top-left corner of a single frame.

    Args:
        image (`torch.Tensor`): the frame to draw on; converted with `to_pil_image`.
        timestamp (`str`): the text to draw.
        size_factor (`float`, *optional*, defaults to 0.1): font size as a fraction of the shorter image
            side; the outline width is the same fraction of the font size.

    Returns:
        `torch.Tensor`: the annotated frame, converted back with `pil_to_tensor`.

    Raises:
        AttributeError: if no font is configured.
    """
    if self.font is None:
        raise AttributeError("To draw on frames with Ernie 4.5 VL, you need an associated font; found nothing")

    # FIXME: conversion `torch->PIL->torch` is inefficient ~6ms per frame
    # Left for optimization if anyone want to pick it up
    #
    # This can take up to ~1s in preprocessing (if default sampling is used):
    # 180 (frames) x 6ms = 1080ms = ~1,1s
    pil_frame = to_pil_image(image)

    frame_width, frame_height = pil_frame.size
    text_size = int(min(frame_width, frame_height) * size_factor)
    stroke_width = int(text_size * size_factor)
    text_font = ImageFont.truetype(self.font, text_size)

    black, white = (0, 0, 0), (255, 255, 255)
    canvas = ImageDraw.Draw(pil_frame)
    canvas.text(
        (0, 0),
        timestamp,
        font=text_font,
        fill=black,
        stroke_width=stroke_width,
        stroke_fill=white,
    )
    return pil_to_tensor(pil_frame)
|
|
386
|
+
|
|
387
|
+
def _prepare_input_videos(
    self,
    videos: VideoInput,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    device: Optional[str] = None,
    video_metadata: Optional[list[VideoMetadata]] = None,
    draw_on_frames: bool = True,
) -> list["torch.Tensor"]:
    """
    Prepare the input videos for processing.

    Each video is converted to a channels-first torch tensor and, if `draw_on_frames`, gets its timestamp
    drawn on every frame. It is then padded to an even number of frames by repeating the last frame and
    moved to `device` if given.

    Args:
        videos (`VideoInput`): the videos, iterated in lockstep with `video_metadata`.
        input_data_format (`str` or `ChannelDimension`, *optional*): channel layout of the inputs. When
            `None`, it is inferred from the first video and reused for the following ones.
        device (`str`, *optional*): device to move each video to.
        video_metadata (`list[VideoMetadata]`, *optional*): one entry per video; required for drawing.
            A missing `fps` is set to 24 in place on the metadata object when drawing.
        draw_on_frames (`bool`, *optional*, defaults to `True`): whether to draw timestamps on the frames.

    Returns:
        `list[torch.Tensor]`: the prepared videos.

    Raises:
        ValueError: if drawing on frames and a video has no metadata.
        RuntimeError: if drawing on frames while tracing (e.g. under `torch.compile`).
    """
    processed_videos = []
    for video, metadata in zip(videos, video_metadata):
        # Check for attributes that are necessary to draw timestamps on frames
        if draw_on_frames:
            if metadata is None:
                raise ValueError("Need video metadata to process videos in Ernie 4.5 VL using `draw_on_frames`")
            elif metadata.fps is None:
                metadata.fps = 24
                logger.warning_once(
                    "Could not infer the fps of a video due to the metadata not being available, "
                    "defaulting to `24`. Please provide `video_metadata` for more accurate results."
                )

        # `make_batched_videos` always returns a 4D array per video
        if isinstance(video, np.ndarray):
            # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
            video = torch.from_numpy(video).contiguous()

        # Infer the channel dimension format if not provided
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(video)

        if input_data_format == ChannelDimension.LAST:
            video = video.permute(0, 3, 1, 2).contiguous()

        # specific to ernie, draws timestamps on each frame (if enabled)
        if draw_on_frames:
            if is_tracing(video):
                raise RuntimeError(
                    "Using `torch.compile` is not compatible with drawing on frames. "
                    "Either don't use `torch.compile` or don't draw on frames via the kwarg `draw_on_frames=False`."
                )

            # Frames are overwritten in place below. `torch.from_numpy` and `.contiguous()` (on an already
            # contiguous tensor) share memory with the caller's input, so draw on a copy instead.
            video = video.clone()
            for idx, frame in enumerate(video):
                video[idx] = self._render_image_with_timestamp(
                    frame, self._convert_timestamp(metadata.timestamps[idx])
                )

        # last frame is copied if uneven (mitigating issues for temporal patch size)
        if video.shape[0] % 2 != 0:
            video = torch.cat((video, video[-1].detach().clone()[None, ...]), dim=0)

        if device is not None:
            video = video.to(device)

        processed_videos.append(video)
    return processed_videos
|
|
445
|
+
|
|
446
|
+
def _preprocess(
    self,
    videos: list[torch.Tensor],
    do_convert_rgb: bool = True,
    do_resize: bool = True,
    size: Optional[SizeDict] = None,
    interpolation: PILImageResampling = PILImageResampling.BICUBIC,
    do_rescale: bool = True,
    rescale_factor: float = 1 / 255.0,
    do_normalize: bool = True,
    image_mean: Optional[Union[float, list[float]]] = None,
    image_std: Optional[Union[float, list[float]]] = None,
    patch_size: Optional[int] = None,
    merge_size: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    **kwargs,
):
    """
    Resize, rescale/normalize and flatten the videos into vision-encoder patches.

    Videos of equal shape are processed together as stacked batches, laid out as
    (batch, frames, channels, height, width) (see the `patches.shape[:3]` unpacking below).

    Args:
        videos (`list[torch.Tensor]`): the prepared, channels-first videos.
        size (`SizeDict`, *optional*): `shortest_edge` / `longest_edge` are used as the min / max pixel
            budgets of `smart_resize`; required when `do_resize`.
        patch_size (`int`, *optional*): spatial patch size; required.
        merge_size (`int`, *optional*): spatial merge factor; required. The resize factor is
            `patch_size * merge_size`.
        Other arguments toggle / parameterize the corresponding transforms.

    Returns:
        `BatchFeature` with:
        - `pixel_values_videos`: per-video flattened patches concatenated along dim 0, each row holding
          `channel * patch_size * patch_size` values;
        - `video_grid_thw`: one `[grid_t, grid_h, grid_w]` row per video, where `grid_t` is the number of
          frames (no temporal patching happens here) and `grid_h` / `grid_w` count spatial patches.
    """
    # Group videos by size for batched resizing
    grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
    resized_videos_grouped = {}
    for shape, stacked_videos in grouped_videos.items():
        if do_convert_rgb:
            stacked_videos = self.convert_to_rgb(stacked_videos)

        height, width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST)
        resized_height, resized_width = height, width
        if do_resize:
            # Presumably returns a size divisible by `factor` within the pixel budgets -- TODO confirm
            # against `smart_resize`; the `view` below requires grid_h / grid_w divisible by merge_size.
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=patch_size * merge_size,
                min_pixels=size["shortest_edge"],
                max_pixels=size["longest_edge"],
            )
            stacked_videos = self.resize(
                image=stacked_videos,
                size=SizeDict(height=resized_height, width=resized_width),
                interpolation=interpolation,
            )
        resized_videos_grouped[shape] = stacked_videos
    resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)

    # Group videos by size for further processing
    # Needed in case do_resize is False, or resize returns videos with different sizes
    grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
    processed_videos_grouped = {}
    processed_grids = {}
    for shape, stacked_videos in grouped_videos.items():
        resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST)

        # Fused rescale and normalize
        stacked_videos = self.rescale_and_normalize(
            stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
        )
        patches = stacked_videos

        # Every frame is its own temporal step: grid_t is the frame count.
        batch_size, grid_t, channel = patches.shape[:3]
        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

        # Split H and W into (blocks of merge_size patches, merge_size, patch pixels).
        patches = patches.view(
            batch_size,
            grid_t,
            channel,
            grid_h // merge_size,
            merge_size,
            patch_size,
            grid_w // merge_size,
            merge_size,
            patch_size,
        )
        # Reorder dimensions to group grid and patch information for subsequent flattening.
        # [batch, grid_t, grid_h/merge, grid_w/merge, merge, merge, channel, patch, patch]
        patches = patches.permute(0, 1, 3, 6, 4, 7, 2, 5, 8)

        # One row per patch; patches of the same merge block end up adjacent.
        flatten_patches = patches.reshape(
            batch_size,
            grid_t * grid_h * grid_w,
            channel * patch_size * patch_size,
        )

        processed_videos_grouped[shape] = flatten_patches
        processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

    processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
    processed_grids = reorder_videos(processed_grids, grouped_videos_index)
    pixel_values_videos = torch.cat(processed_videos, dim=0)
    video_grid_thw = torch.tensor(processed_grids)

    return BatchFeature(
        data={"pixel_values_videos": pixel_values_videos, "video_grid_thw": video_grid_thw},
        tensor_type=return_tensors,
    )
|
|
538
|
+
|
|
539
|
+
@add_start_docstrings(
|
|
540
|
+
BASE_VIDEO_PROCESSOR_DOCSTRING,
|
|
541
|
+
)
|
|
542
|
+
def preprocess(
|
|
543
|
+
self,
|
|
544
|
+
videos: VideoInput,
|
|
545
|
+
**kwargs: Unpack[VideosKwargs],
|
|
546
|
+
) -> BatchFeature:
|
|
547
|
+
validate_kwargs(
|
|
548
|
+
captured_kwargs=kwargs.keys(),
|
|
549
|
+
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
|
|
550
|
+
)
|
|
551
|
+
|
|
552
|
+
# Perform type validation on received kwargs
|
|
553
|
+
validate_typed_dict(self.valid_kwargs, kwargs)
|
|
554
|
+
|
|
555
|
+
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
|
556
|
+
# by the user, it gets its default value from the instance, or is set to None.
|
|
557
|
+
for kwarg_name in self.valid_kwargs.__annotations__:
|
|
558
|
+
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
|
559
|
+
|
|
560
|
+
input_data_format = kwargs.pop("input_data_format")
|
|
561
|
+
do_sample_frames = kwargs.pop("do_sample_frames")
|
|
562
|
+
device = kwargs.pop("device")
|
|
563
|
+
video_metadata = kwargs.pop("video_metadata")
|
|
564
|
+
draw_on_frames = kwargs.pop("draw_on_frames")
|
|
565
|
+
|
|
566
|
+
sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
|
|
567
|
+
videos, video_metadata = self._decode_and_sample_videos(
|
|
568
|
+
videos,
|
|
569
|
+
video_metadata=video_metadata,
|
|
570
|
+
do_sample_frames=do_sample_frames,
|
|
571
|
+
sample_indices_fn=sample_indices_fn,
|
|
572
|
+
)
|
|
573
|
+
videos = self._prepare_input_videos(
|
|
574
|
+
videos=videos,
|
|
575
|
+
input_data_format=input_data_format,
|
|
576
|
+
device=device,
|
|
577
|
+
video_metadata=video_metadata,
|
|
578
|
+
draw_on_frames=draw_on_frames,
|
|
579
|
+
)
|
|
580
|
+
|
|
581
|
+
kwargs = self._further_process_kwargs(**kwargs)
|
|
582
|
+
self._validate_preprocess_kwargs(**kwargs)
|
|
583
|
+
|
|
584
|
+
# Pop kwargs that are not needed in _preprocess
|
|
585
|
+
kwargs.pop("data_format")
|
|
586
|
+
return_metadata = kwargs.pop("return_metadata")
|
|
587
|
+
|
|
588
|
+
preprocessed_videos = self._preprocess(videos=videos, **kwargs)
|
|
589
|
+
if return_metadata:
|
|
590
|
+
preprocessed_videos["video_metadata"] = video_metadata
|
|
591
|
+
return preprocessed_videos
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
__all__ = ["Ernie4_5_VL_MoeVideoProcessor"]
|
|
@@ -90,6 +90,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
|
90
90
|
|
|
91
91
|
def __init__(self, dim: int):
|
|
92
92
|
super().__init__()
|
|
93
|
+
self.dim = dim
|
|
93
94
|
# Generate and save the inverse frequency buffer (non trainable)
|
|
94
95
|
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
|
|
95
96
|
self.register_buffer("inv_freq", inv_freq)
|
|
@@ -558,6 +559,11 @@ class EsmPreTrainedModel(PreTrainedModel):
|
|
|
558
559
|
super()._init_weights(module)
|
|
559
560
|
if isinstance(module, EsmLMHead):
|
|
560
561
|
init.zeros_(module.bias)
|
|
562
|
+
elif isinstance(module, EsmEmbeddings):
|
|
563
|
+
init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
|
|
564
|
+
elif isinstance(module, RotaryEmbedding):
|
|
565
|
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
|
|
566
|
+
init.copy_(module.inv_freq, inv_freq)
|
|
561
567
|
|
|
562
568
|
def get_output_embeddings(self):
|
|
563
569
|
# NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
|
|
@@ -912,7 +912,7 @@ class EsmFoldPreTrainedModel(EsmPreTrainedModel):
|
|
|
912
912
|
elif module.init == "gating":
|
|
913
913
|
init.zeros_(module.weight)
|
|
914
914
|
if module.bias:
|
|
915
|
-
init.
|
|
915
|
+
init.ones_(module.bias)
|
|
916
916
|
elif module.init == "normal":
|
|
917
917
|
init.kaiming_normal_(module.weight, nonlinearity="linear")
|
|
918
918
|
elif module.init == "final":
|
|
@@ -1979,6 +1979,11 @@ class EsmForProteinFolding(EsmPreTrainedModel):
|
|
|
1979
1979
|
|
|
1980
1980
|
_can_record_outputs = None
|
|
1981
1981
|
|
|
1982
|
+
def _init_weights(self, module):
|
|
1983
|
+
super()._init_weights(module)
|
|
1984
|
+
if isinstance(module, EsmForProteinFolding):
|
|
1985
|
+
init.copy_(module.af2_to_esm, module._af2_to_esm_from_vocab_list(module.config.vocab_list))
|
|
1986
|
+
|
|
1982
1987
|
def __init__(self, config):
|
|
1983
1988
|
super().__init__(config)
|
|
1984
1989
|
|
|
@@ -185,6 +185,7 @@ class EvollaSaProtRotaryEmbedding(nn.Module):
|
|
|
185
185
|
|
|
186
186
|
def __init__(self, dim: int):
|
|
187
187
|
super().__init__()
|
|
188
|
+
self.dim = dim
|
|
188
189
|
# Generate and save the inverse frequency buffer (non trainable)
|
|
189
190
|
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
|
|
190
191
|
self.register_buffer("inv_freq", inv_freq)
|
|
@@ -518,12 +519,19 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel):
|
|
|
518
519
|
],
|
|
519
520
|
}
|
|
520
521
|
|
|
522
|
+
def _init_weights(self, module):
|
|
523
|
+
super()._init_weights(module)
|
|
524
|
+
if isinstance(module, EvollaSaProtRotaryEmbedding):
|
|
525
|
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
|
|
526
|
+
init.copy_(module.inv_freq, inv_freq)
|
|
527
|
+
|
|
521
528
|
|
|
522
529
|
class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
|
|
523
530
|
def __init__(self, config: SaProtConfig):
|
|
524
531
|
super().__init__(config)
|
|
525
532
|
self.embeddings = EvollaSaProtEmbeddings(config)
|
|
526
533
|
self.encoder = EvollaSaProtEncoder(config)
|
|
534
|
+
self.post_init()
|
|
527
535
|
|
|
528
536
|
def get_input_embeddings(self):
|
|
529
537
|
return self.embeddings.word_embeddings
|
|
@@ -980,7 +988,7 @@ class EvollaRotaryEmbedding(nn.Module):
|
|
|
980
988
|
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
|
981
989
|
|
|
982
990
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
983
|
-
self.original_inv_freq =
|
|
991
|
+
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
|
984
992
|
|
|
985
993
|
@staticmethod
|
|
986
994
|
def compute_default_rope_parameters(
|