transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -347,6 +347,22 @@ class FuyuProcessor(ProcessorMixin):
             The tokenizer is a required input.
     """

+    @classmethod
+    def _load_tokenizer_from_pretrained(
+        cls, sub_processor_type, pretrained_model_name_or_path, subfolder="", **kwargs
+    ):
+        """
+        Override for BC. Fuyu uses TokenizersBackend and requires token_type_ids to be removed from model_input_names
+        because Fuyu uses mm_token_type_ids instead for multimodal token identification.
+        """
+        from ...tokenization_utils_tokenizers import TokenizersBackend
+
+        tokenizer = TokenizersBackend.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        # Remove token_type_ids as Fuyu uses mm_token_type_ids instead
+        if "token_type_ids" in tokenizer.model_input_names:
+            tokenizer.model_input_names.remove("token_type_ids")
+        return tokenizer
+
     def __init__(self, image_processor, tokenizer, **kwargs):
         super().__init__(image_processor=image_processor, tokenizer=tokenizer)
         self.image_processor = image_processor
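In practice, the override above means the tokenizer wrapped by `FuyuProcessor` no longer advertises `token_type_ids` among its model inputs, so batches carry only the text inputs plus the processor's `mm_token_type_ids`. A minimal usage sketch (the `adept/fuyu-8b` checkpoint name is an assumption for illustration, not taken from the diff):

```python
# Illustrative sketch only; checkpoint name is an assumption.
from transformers import FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
# With the rc2 override, "token_type_ids" should be absent from the tokenizer's model inputs.
print(processor.tokenizer.model_input_names)
```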
@@ -98,7 +98,7 @@ class GemmaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
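The change above (repeated for the Gemma2/Gemma3 rotary embeddings below) stores `original_inv_freq` via `register_buffer(..., persistent=False)` rather than as a plain attribute. A minimal, standalone sketch of the practical difference (generic PyTorch, not Gemma code): a non-persistent buffer follows `module.to(...)` device/dtype moves but stays out of the state dict, while a plain tensor attribute is left behind.

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        freq = torch.arange(4, dtype=torch.float32)
        self.register_buffer("original_inv_freq", freq.clone(), persistent=False)
        self.plain_attr = freq.clone()  # ordinary attribute, invisible to nn.Module

m = Demo().to(torch.float64)
print(m.original_inv_freq.dtype)               # torch.float64 -- moved/cast with the module
print(m.plain_attr.dtype)                      # torch.float32 -- untouched
print("original_inv_freq" in m.state_dict())   # False -- persistent=False keeps it out of checkpoints
```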
@@ -410,16 +410,14 @@ class GemmaModel(GemmaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-
-
-
-
-
-
-
-
-            position_ids=position_ids,
-        )
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )

         # embed positions
         hidden_states = inputs_embeds

@@ -434,7 +432,7 @@ class GemmaModel(GemmaPreTrainedModel):
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=
+                attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
@@ -267,16 +267,14 @@ class GemmaModel(LlamaModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-
-
-
-
-
-
-
-
-            position_ids=position_ids,
-        )
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )

         # embed positions
         hidden_states = inputs_embeds

@@ -291,7 +289,7 @@ class GemmaModel(LlamaModel):
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=
+                attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
@@ -99,7 +99,7 @@ class Gemma2RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(

@@ -244,7 +244,7 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
@@ -231,7 +231,6 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
         return BatchFeature(
             data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors
         )
@@ -100,6 +100,7 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding):

     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.scalar_embed_scale = embed_scale
         self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

     def forward(self, input_ids: torch.Tensor):

@@ -165,7 +166,7 @@ class Gemma3RotaryEmbedding(nn.Module):
             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
             curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
             self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-
+            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
             setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)

     @staticmethod
@@ -468,6 +469,16 @@ class Gemma3PreTrainedModel(PreTrainedModel):
         # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
         elif "RMSNorm" in module.__class__.__name__:
             init.zeros_(module.weight)
+        elif isinstance(module, Gemma3TextScaledWordEmbedding):
+            init.constant_(module.embed_scale, module.scalar_embed_scale)
+        elif isinstance(module, Gemma3RotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)


 def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
@@ -754,6 +765,7 @@ def create_causal_mask_mapping(
     token_type_ids: Optional[torch.Tensor] = None,
     pixel_values: Optional[torch.FloatTensor] = None,
     is_training: bool = False,
+    is_first_iteration: Optional[bool] = None,
     **kwargs,
 ) -> dict:
     """

@@ -776,8 +788,12 @@ def create_causal_mask_mapping(
     # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
     # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
     # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
-
-
+    is_first_iteration = (
+        is_first_iteration
+        if is_first_iteration is not None
+        else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
+    )
+    if token_type_ids is not None and is_first_iteration:
         # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
         # undo the causal masking)

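Restated outside the diff, the new `is_first_iteration` fallback resolves as follows. This is a simplified sketch: the helper name and the single `cache_is_initialized` flag are illustrative stand-ins, whereas the real code checks `past_key_values is None or not past_key_values.is_initialized or pixel_values is not None`.

```python
from typing import Optional

def resolve_is_first_iteration(explicit: Optional[bool], cache_is_initialized: bool, has_pixel_values: bool) -> bool:
    # An explicit value from the caller always wins; otherwise fall back to the heuristic.
    if explicit is not None:
        return explicit
    return (not cache_is_initialized) or has_pixel_values

assert resolve_is_first_iteration(None, cache_is_initialized=False, has_pixel_values=False) is True
assert resolve_is_first_iteration(None, cache_is_initialized=True, has_pixel_values=False) is False
assert resolve_is_first_iteration(False, cache_is_initialized=True, has_pixel_values=True) is False
```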
@@ -1123,6 +1139,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         use_cache=True,
         logits_to_keep=None,
         labels=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- custom `position_ids` and `pixel_values` handling

@@ -1136,12 +1153,15 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             use_cache=use_cache,
             logits_to_keep=logits_to_keep,
             token_type_ids=token_type_ids,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        #
-        #
-
+        # Pixel values are used only in the first iteration if available
+        # In subsquent iterations, they are already merged with text and cached
+        # NOTE: first iteration doesn't have to be prefill, it can be the first
+        # iteration with a question and cached system prompt (continue generate from cache). NOTE: use_cache=False needs pixel_values always
+        if is_first_iteration or not use_cache:
             model_inputs["pixel_values"] = pixel_values

         return model_inputs

@@ -1155,6 +1175,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         past_key_values: Optional[Cache],
         position_ids: Optional[torch.Tensor],
         token_type_ids: Optional[torch.Tensor] = None,
+        is_first_iteration: Optional[bool] = False,
         **kwargs,
     ) -> dict:
         # Uses the overwritten `create_masks_for_generate` with `token_type_ids` masking

@@ -1166,7 +1187,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             past_key_values,
             position_ids,
             token_type_ids,
-
+            is_first_iteration=is_first_iteration,
             **{k: v for k, v in kwargs.items() if k != "pixel_values"},
         )

@@ -352,6 +352,7 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding):

     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.scalar_embed_scale = embed_scale
         self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

     def forward(self, input_ids: torch.Tensor):

@@ -389,7 +390,7 @@ class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding):
             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
             curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
             self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-
+            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
             setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)

     @staticmethod
@@ -576,6 +577,16 @@ class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
         # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
         elif "RMSNorm" in module.__class__.__name__:
             init.zeros_(module.weight)
+        elif isinstance(module, Gemma3TextScaledWordEmbedding):
+            init.constant_(module.embed_scale, module.scalar_embed_scale)
+        elif isinstance(module, Gemma3RotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)


 def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
@@ -734,6 +745,7 @@ def create_causal_mask_mapping(
     token_type_ids: Optional[torch.Tensor] = None,
     pixel_values: Optional[torch.FloatTensor] = None,
     is_training: bool = False,
+    is_first_iteration: Optional[bool] = None,
     **kwargs,
 ) -> dict:
     """

@@ -756,8 +768,12 @@ def create_causal_mask_mapping(
     # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
     # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
     # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
-
-
+    is_first_iteration = (
+        is_first_iteration
+        if is_first_iteration is not None
+        else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
+    )
+    if token_type_ids is not None and is_first_iteration:
         # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
         # undo the causal masking)

@@ -1005,6 +1021,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
         use_cache=True,
         logits_to_keep=None,
         labels=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- custom `position_ids` and `pixel_values` handling

@@ -1018,12 +1035,15 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
             use_cache=use_cache,
             logits_to_keep=logits_to_keep,
             token_type_ids=token_type_ids,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        #
-        #
-
+        # Pixel values are used only in the first iteration if available
+        # In subsquent iterations, they are already merged with text and cached
+        # NOTE: first iteration doesn't have to be prefill, it can be the first
+        # iteration with a question and cached system prompt (continue generate from cache). NOTE: use_cache=False needs pixel_values always
+        if is_first_iteration or not use_cache:
             model_inputs["pixel_values"] = pixel_values

         return model_inputs
@@ -495,6 +495,9 @@ class Gemma3nVisionConfig(PreTrainedConfig):

     @classmethod
     def from_dict(cls, config_dict: dict[str, Any], **kwargs):
+        # Create a copy to avoid mutating the original dict
+        config_dict = config_dict.copy()
+
         label_names = config_dict.get("label_names")
         is_custom_model = "num_labels" in kwargs or "id2label" in kwargs

@@ -329,6 +329,16 @@ class Gemma3nAudioAttention(nn.Module):
         r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
         self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)

+        local_causal_valid_mask = self.create_local_causal_valid_mask()
+        self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
+
+        self.register_buffer(
+            "softcap",
+            torch.tensor(self.attention_logits_soft_cap).float(),
+            persistent=False,
+        )
+
+    def create_local_causal_valid_mask(self):
         lower_causal_mask = torch.tril(
             torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
             diagonal=0,
@@ -339,13 +349,7 @@ class Gemma3nAudioAttention(nn.Module):
         )
         local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
         local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
-
-
-        self.register_buffer(
-            "softcap",
-            torch.tensor(self.attention_logits_soft_cap).float(),
-            persistent=False,
-        )
+        return local_causal_valid_mask

     def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
         batch, _, *tail_shape = x.shape
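A runnable aside on the `persistent=False` buffers registered above (`softcap`, `local_causal_valid_mask`): such buffers move with the module across devices and dtypes but are not written to the state dict, which is consistent with the `_init_weights` additions further down that recreate their values.

import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("softcap", torch.tensor(50.0), persistent=False)

m = Demo()
print("softcap" in m.state_dict())   # False: non-persistent buffers are not checkpointed
print(m.softcap)                     # tensor(50.) -- still available at runtime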
@@ -919,6 +923,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
         self.conformer = nn.ModuleList(
             [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
         )
+        self.post_init()

     def forward(
         self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs
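`post_init()` is the standard `PreTrainedModel` hook that runs weight initialization and related setup once all submodules exist, which is why it is added at the end of `__init__` here. A minimal sketch of the pattern, with a placeholder config class that is not from the diff:

from torch import nn
from transformers.configuration_utils import PreTrainedConfig
from transformers.modeling_utils import PreTrainedModel

class ToyConfig(PreTrainedConfig):   # placeholder config, not a class from the diff
    model_type = "toy"

class ToyEncoder(PreTrainedModel):
    config_class = ToyConfig

    def __init__(self, config):
        super().__init__(config)
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(2))
        self.post_init()             # weights are initialized only after all modules exist

model = ToyEncoder(ToyConfig())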
@@ -983,6 +988,7 @@ class Gemma3nTextScaledWordEmbedding(nn.Embedding):

     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.scalar_embed_scale = embed_scale
         self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

     def forward(self, input_ids: torch.Tensor):
@@ -1449,8 +1455,38 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, Gemma3nAudioAttention):
             init.zeros_(module.per_dim_scale)
+            q_scale = module.head_dim**-0.5
+            r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
+            init.copy_(module.q_scale, q_scale * r_softplus_0)
+            init.constant_(module.softcap, module.attention_logits_soft_cap)
+            init.copy_(module.local_causal_valid_mask, module.create_local_causal_valid_mask())
+        elif isinstance(module, Gemma3nTextScaledWordEmbedding):
+            init.constant_(module.embed_scale, module.scalar_embed_scale)
         elif isinstance(module, Gemma3nTextAltUp):
             init.zeros_(module.correct_output_scale)
+            init.constant_(module.router_input_scale, self.config.hidden_size**-1.0)
+        elif isinstance(module, Gemma3nAudioRelativePositionEmbedding):
+            min_timescale, max_timescale = 1.0, 1.0e4
+            num_timescales = module.channels // 2
+            log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
+                num_timescales - 1, 1
+            )
+            inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
+            init.copy_(module.inv_timescales, inv_timescales.float().unsqueeze(0).unsqueeze(0))
+        elif isinstance(module, Gemma3nTextModel):
+            init.constant_(module.per_layer_projection_scale, self.hidden_size**-0.5)
+            init.constant_(module.per_layer_input_scale, 1 / math.sqrt(2.0))
+        elif isinstance(module, Gemma3nRotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
+
+        if hasattr(module, "gradient_clipping"):
+            init.constant_(module.gradient_clipping, self.config.gradient_clipping)


 class Gemma3nRotaryEmbedding(nn.Module):
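The `init.copy_` / `init.constant_` calls above amount to in-place re-initialization of parameters and buffers; a rough plain-torch analogue of their effect (an assumption about behavior, not their actual implementation):

import torch

buf = torch.empty(4)
with torch.no_grad():
    buf.copy_(torch.arange(4, dtype=torch.float32))   # analogous to init.copy_(buffer, values)
    buf.fill_(0.5)                                     # analogous to init.constant_(buffer, 0.5)
print(buf)   # tensor([0.5000, 0.5000, 0.5000, 0.5000])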
@@ -1476,7 +1512,7 @@ class Gemma3nRotaryEmbedding(nn.Module):
             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
             curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
             self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-
+            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
             setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)

     @staticmethod
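The `.clone()` in the line added above keeps `*_original_inv_freq` on independent storage, so later in-place updates to the live `*_inv_freq` buffer (for example under dynamic RoPE rescaling) cannot silently overwrite the saved original:

import torch

inv_freq = torch.tensor([1.0, 0.5, 0.25])
original_inv_freq = inv_freq.clone()   # independent copy
inv_freq.mul_(0.5)                     # simulate an in-place rescale of the live buffer
print(original_inv_freq)               # tensor([1.0000, 0.5000, 0.2500]) -- unchanged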
@@ -2301,6 +2337,7 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
         use_cache=True,
         logits_to_keep=None,
         labels=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -2314,13 +2351,14 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
             use_cache=use_cache,
             logits_to_keep=logits_to_keep,
             token_type_ids=token_type_ids,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

         # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
         # tokens anymore. Otherwise multimodal inputs should be passed to model.
         # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
-        if
+        if is_first_iteration or not use_cache:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["input_features"] = input_features
             model_inputs["input_features_mask"] = input_features_mask
@@ -27,7 +27,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_outputs import BaseModelOutputWithPast
-from ...modeling_rope_utils import RopeParameters
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
@@ -45,6 +45,7 @@ from ..gemma3.modeling_gemma3 import (
     Gemma3DecoderLayer,
     Gemma3ForCausalLM,
     Gemma3RMSNorm,
+    Gemma3RotaryEmbedding,
     Gemma3TextModel,
     Gemma3TextScaledWordEmbedding,
 )
@@ -882,6 +883,16 @@ class Gemma3nAudioAttention(nn.Module):
         r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
         self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)

+        local_causal_valid_mask = self.create_local_causal_valid_mask()
+        self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
+
+        self.register_buffer(
+            "softcap",
+            torch.tensor(self.attention_logits_soft_cap).float(),
+            persistent=False,
+        )
+
+    def create_local_causal_valid_mask(self):
         lower_causal_mask = torch.tril(
             torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
             diagonal=0,
@@ -892,13 +903,7 @@ class Gemma3nAudioAttention(nn.Module):
         )
         local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
         local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
-
-
-        self.register_buffer(
-            "softcap",
-            torch.tensor(self.attention_logits_soft_cap).float(),
-            persistent=False,
-        )
+        return local_causal_valid_mask

     def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
         batch, _, *tail_shape = x.shape
@@ -1472,6 +1477,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
         self.conformer = nn.ModuleList(
             [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
         )
+        self.post_init()

     def forward(
         self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs
@@ -1892,8 +1898,42 @@ class Gemma3nPreTrainedModel(Gemma2PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, Gemma3nAudioAttention):
             init.zeros_(module.per_dim_scale)
+            q_scale = module.head_dim**-0.5
+            r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
+            init.copy_(module.q_scale, q_scale * r_softplus_0)
+            init.constant_(module.softcap, module.attention_logits_soft_cap)
+            init.copy_(module.local_causal_valid_mask, module.create_local_causal_valid_mask())
+        elif isinstance(module, Gemma3nTextScaledWordEmbedding):
+            init.constant_(module.embed_scale, module.scalar_embed_scale)
         elif isinstance(module, Gemma3nTextAltUp):
             init.zeros_(module.correct_output_scale)
+            init.constant_(module.router_input_scale, self.config.hidden_size**-1.0)
+        elif isinstance(module, Gemma3nAudioRelativePositionEmbedding):
+            min_timescale, max_timescale = 1.0, 1.0e4
+            num_timescales = module.channels // 2
+            log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
+                num_timescales - 1, 1
+            )
+            inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
+            init.copy_(module.inv_timescales, inv_timescales.float().unsqueeze(0).unsqueeze(0))
+        elif isinstance(module, Gemma3nTextModel):
+            init.constant_(module.per_layer_projection_scale, self.hidden_size**-0.5)
+            init.constant_(module.per_layer_input_scale, 1 / math.sqrt(2.0))
+        elif isinstance(module, Gemma3nRotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
+
+        if hasattr(module, "gradient_clipping"):
+            init.constant_(module.gradient_clipping, self.config.gradient_clipping)
+
+
+class Gemma3nRotaryEmbedding(Gemma3RotaryEmbedding):
+    pass


 @auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
@@ -2543,6 +2583,7 @@ class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration):
         use_cache=True,
         logits_to_keep=None,
         labels=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -2556,13 +2597,14 @@ class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration):
             use_cache=use_cache,
             logits_to_keep=logits_to_keep,
             token_type_ids=token_type_ids,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

         # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
         # tokens anymore. Otherwise multimodal inputs should be passed to model.
         # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
-        if
+        if is_first_iteration or not use_cache:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["input_features"] = input_features
             model_inputs["input_features_mask"] = input_features_mask