transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0

--- a/transformers/models/gemma3/modular_gemma3.py
+++ b/transformers/models/gemma3/modular_gemma3.py
@@ -23,7 +23,6 @@ from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
 from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_rope_utils import (
@@ -34,6 +33,7 @@ from ...modeling_rope_utils import (
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import maybe_autocast
 from ..gemma2.configuration_gemma2 import Gemma2Config
 from ..gemma2.modeling_gemma2 import (
     Gemma2Attention,
@@ -438,7 +438,7 @@ class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * attention_scaling
@@ -465,7 +465,7 @@ class Gemma3Attention(Gemma2Attention):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -527,23 +527,19 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states

         hidden_states = self.input_layernorm(hidden_states)

-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
@@ -556,12 +552,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states

-
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return hidden_states


 GEMMA3_START_DOCSTRING = None
@@ -620,30 +611,16 @@ class Gemma3TextModel(Gemma2Model):
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if use_cache and past_key_values is None
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
@@ -684,41 +661,22 @@ class Gemma3TextModel(Gemma2Model):
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_embeddings=position_embeddings[decoder_layer.attention_type],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)

-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )


@@ -853,20 +811,11 @@ class Gemma3Model(PaliGemmaModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Gemma3ModelOutputWithPast]:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # Replace image id with PAD if the image token if OOV, to avoid index-errors
         if input_ids is not None and self.config.image_token_id >= self.vocab_size:
             special_image_mask = input_ids == self.config.image_token_id
@@ -913,8 +862,6 @@ class Gemma3Model(PaliGemmaModel):
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             return_dict=True,
             cache_position=cache_position,
             **lm_kwargs,
@@ -922,7 +869,7 @@ class Gemma3Model(PaliGemmaModel):

         return Gemma3ModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
-            past_key_values=outputs.past_key_values
+            past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
             image_hidden_states=image_features if pixel_values is not None else None,
@@ -934,6 +881,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
     # Fix: https://github.com/huggingface/transformers/issues/40564
     accepts_loss_kwargs = False

+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -947,11 +895,8 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -997,13 +942,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
         "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
         ```
         """
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         outputs = self.model(
             input_ids=input_ids,
             pixel_values=pixel_values,
@@ -1014,9 +952,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             labels=labels,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             cache_position=cache_position,
             **lm_kwargs,
         )
@@ -1048,10 +983,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
             flat_labels = shift_labels.view(-1).to(shift_logits.device)
             loss = loss_fct(flat_logits, flat_labels)

-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
         return Gemma3CausalLMOutputWithPast(
             loss=loss,
             logits=logits,

--- a/transformers/models/gemma3n/modeling_gemma3n.py
+++ b/transformers/models/gemma3n/modeling_gemma3n.py
@@ -32,21 +32,19 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs, maybe_autocast
 from ..auto import AutoModel
 from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig


-logger = logging.get_logger(__name__)
-
-
 @dataclass
 @auto_docstring(
     custom_intro="""
@@ -923,7 +921,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
         )

     def forward(
-        self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
+        self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs
     ) -> tuple[torch.Tensor, torch.BoolTensor]:
         """Encodes a batch of MELs.

@@ -1228,6 +1226,7 @@ def apply_rotary_pos_emb(
     return (x * cos) + (rotate_half(x) * sin)


+@use_kernelized_func(apply_rotary_pos_emb)
 class Gemma3nTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -1254,7 +1253,6 @@ class Gemma3nTextAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
         self.is_sliding = self.layer_type == "sliding_attention"

@@ -1283,7 +1281,7 @@ class Gemma3nTextAttention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.config.head_dim)
@@ -1379,10 +1377,8 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         predictions = self.altup.predict(hidden_states)
         active_prediction = predictions[self.config.altup_active_idx]
@@ -1390,14 +1386,12 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
         active_prediction_normed = self.input_layernorm(active_prediction)
         laurel_output = self.laurel(active_prediction_normed)

-        attn, self_attn_weights = self.self_attn(
+        attn, _ = self.self_attn(
             hidden_states=active_prediction_normed,
             attention_mask=attention_mask,
             position_ids=position_ids,
             position_embeddings=position_embeddings,
             past_key_values=past_key_values,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
@@ -1426,154 +1420,7 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
         first_prediction = self.post_per_layer_input_norm(first_prediction)
         corrected_predictions[1:] += first_prediction

-
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
-
-
-class Gemma3nMLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_activation]
-
-    def forward(self, x):
-        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
-
-
-class Gemma3nAttention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    def __init__(self, config: Gemma3nConfig, layer_idx: int):
-        super().__init__()
-        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
-        self.config = config
-        self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
-        self.scaling = config.query_pre_attn_scalar**-0.5
-        self.attention_dropout = self.config.attention_dropout
-        self.is_causal = not getattr(config, "use_bidirectional_attention", False)
-
-        self.q_proj = nn.Linear(
-            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.k_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.v_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.o_proj = nn.Linear(
-            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
-        )
-        self.rotary_fn = apply_rotary_pos_emb
-        self.attn_logit_softcapping = self.config.attn_logit_softcapping
-        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
-        input_shape = hidden_states.shape[:-1]
-        hidden_shape = (*input_shape, -1, self.head_dim)
-
-        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
-        cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        if past_key_values is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        attention_interface: Callable = eager_attention_forward
-        if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-        attn_output, attn_weights = attention_interface(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            dropout=self.attention_dropout if self.training else 0.0,
-            scaling=self.scaling,
-            sliding_window=self.sliding_window,
-            softcap=self.attn_logit_softcapping,
-            **kwargs,
-        )
-
-        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
-        return attn_output, attn_weights
-
-
-class Gemma3nDecoderLayer(GradientCheckpointingLayer):
-    def __init__(self, config: Gemma3nConfig, layer_idx: int):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-        self.config = config
-        self.attention_type = config.layer_types[layer_idx]
-        self.self_attn = Gemma3nAttention(config=config, layer_idx=layer_idx)
-        self.mlp = Gemma3nMLP(config)
-        self.input_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-        self.pre_feedforward_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_feedforward_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        residual = hidden_states
-
-        hidden_states = self.input_layernorm(hidden_states)
-
-        # Self Attention
-        hidden_states, _ = self.self_attn(
-            hidden_states=hidden_states,
-            position_embeddings=position_embeddings,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            cache_position=cache_position,
-            **kwargs,
-        )
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = residual + hidden_states
-
-        residual = hidden_states
-        hidden_states = self.pre_feedforward_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = self.post_feedforward_layernorm(hidden_states)
-        hidden_states = residual + hidden_states
-
-        return hidden_states
+        return corrected_predictions


 @auto_docstring
@@ -1590,8 +1437,8 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
     _can_compile_fullgraph = True
     _supports_attention_backend = True
     _can_record_outputs = {
-        "hidden_states":
-        "attentions":
+        "hidden_states": Gemma3nTextDecoderLayer,
+        "attentions": Gemma3nTextAttention,
     }
     input_modalities = ("image", "text", "audio")

@@ -1678,7 +1525,7 @@ class Gemma3nRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * attention_scaling
@@ -1741,7 +1588,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    @
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
@@ -1752,8 +1599,6 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
@@ -1761,37 +1606,21 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
         per_layer_inputs (torch.Tensor, *optional*, defaults to None):
             Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
         if input_ids is not None:
             inputs_embeds = self.embed_tokens(input_ids)
             per_layer_inputs = self.get_per_layer_inputs(input_ids)

         per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)

-        if use_cache and past_key_values is None
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
-                past_seen_tokens,
-                past_seen_tokens + inputs_embeds.shape[1],
-                device=inputs_embeds.device,
-            )
+            cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
@@ -1835,39 +1664,21 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
             causal_mask = causal_mask_mapping[decoder_layer.attention_type]
             per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]

-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 position_embeddings[decoder_layer.attention_type],
                 per_layer_input,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        # add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output)
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         # Per-layer inputs to single output
         target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
         temp_hidden_states = [hidden_states[0]]
@@ -1887,8 +1698,6 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )

     def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
@@ -2175,7 +1984,7 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Gemma3nCausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -2363,7 +2172,7 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Gemma3nCausalLMOutputWithPast:
         r"""
         input_features_mask (torch.Tensor, *optional*, defaults to None):