transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/models/funnel/tokenization_funnel.py

@@ -14,7 +14,7 @@
 # limitations under the License.
 """Tokenization class for Funnel Transformer."""

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import WordPiece
@@ -83,16 +83,17 @@ class FunnelTokenizer(TokenizersBackend):
             value for `lowercase` (as in the original BERT).
         wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
             The prefix for subwords.
-        vocab (`dict`, *optional*):
+        vocab (`str` or `dict[str, int]`, *optional*):
             Custom vocabulary dictionary.
     """

     vocab_files_names = VOCAB_FILES_NAMES
-
+    model = WordPiece
     cls_token_type_id: int = 2

     def __init__(
         self,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
         do_lower_case: bool = True,
         unk_token: str = "<unk>",
         sep_token: str = "<sep>",
@@ -105,23 +106,18 @@ class FunnelTokenizer(TokenizersBackend):
         tokenize_chinese_chars: bool = True,
         strip_accents: Optional[bool] = None,
         wordpieces_prefix: str = "##",
-        vocab: Optional[dict] = None,
-        vocab_file: Optional[str] = None,
         **kwargs,
     ):
-        self.vocab_file = vocab_file
         self.do_lower_case = do_lower_case
         self.tokenize_chinese_chars = tokenize_chinese_chars
         self.strip_accents = strip_accents
         self.clean_text = clean_text
         self.wordpieces_prefix = wordpieces_prefix

-
-
-
-
-        else:
-            self._vocab = {
+        self._vocab = (
+            vocab
+            if vocab is not None
+            else {
                 str(pad_token): 0,
                 str(unk_token): 1,
                 str(cls_token): 2,
@@ -130,6 +126,7 @@ class FunnelTokenizer(TokenizersBackend):
                 str(bos_token): 5,
                 str(eos_token): 6,
             }
+        )

         self._tokenizer = Tokenizer(WordPiece(self._vocab, unk_token=str(unk_token)))

@@ -142,19 +139,7 @@ class FunnelTokenizer(TokenizersBackend):
         self._tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
         self._tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)

-        self._tokenizer.post_processor = processors.TemplateProcessing(
-            single=f"{cls_token}:2 $A:0 {sep_token}:0",  # token_type_id is 2 for Funnel transformer
-            pair=f"{cls_token}:2 $A:0 {sep_token}:0 $B:1 {sep_token}:1",
-            special_tokens=[
-                (str(cls_token), self._vocab.get(str(cls_token), 2)),
-                (str(sep_token), self._vocab.get(str(sep_token), 3)),
-            ],
-        )
-
-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             do_lower_case=do_lower_case,
             unk_token=unk_token,
             sep_token=sep_token,
@@ -169,6 +154,14 @@ class FunnelTokenizer(TokenizersBackend):
             wordpieces_prefix=wordpieces_prefix,
             **kwargs,
         )
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=f"{cls_token}:2 $A:0 {sep_token}:0",  # token_type_id is 2 for Funnel transformer
+            pair=f"{cls_token}:2 $A:0 {sep_token}:0 $B:1 {sep_token}:1",
+            special_tokens=[
+                (str(cls_token), self.cls_token_id),
+                (str(sep_token), self.sep_token_id),
+            ],
+        )


 __all__ = ["FunnelTokenizer"]
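The FunnelTokenizer change above attaches the TemplateProcessing post-processor after `super().__init__` and resolves the CLS/SEP ids through `self.cls_token_id`/`self.sep_token_id` instead of raw vocab lookups. For readers unfamiliar with that pipeline, the following is a minimal standalone sketch built directly with the `tokenizers` library; the tiny vocabulary and token strings are illustrative only and are not taken from the package.

from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
from tokenizers.models import WordPiece

# Toy WordPiece vocab, purely for illustration.
vocab = {"<pad>": 0, "<unk>": 1, "<cls>": 2, "<sep>": 3, "hello": 4, "world": 5}

tok = Tokenizer(WordPiece(vocab, unk_token="<unk>"))
tok.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
tok.decoder = decoders.WordPiece(prefix="##")
# Same shape of template as in the diff: the CLS token carries token_type_id 2.
tok.post_processor = processors.TemplateProcessing(
    single="<cls>:2 $A:0 <sep>:0",
    pair="<cls>:2 $A:0 <sep>:0 $B:1 <sep>:1",
    special_tokens=[("<cls>", vocab["<cls>"]), ("<sep>", vocab["<sep>"])],
)

enc = tok.encode("hello world")
print(enc.ids)       # [2, 4, 5, 3] -> CLS, hello, world, SEP
print(enc.type_ids)  # [2, 0, 0, 0] -> CLS gets type id 2, as in Funnel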
transformers/models/fuyu/processing_fuyu.py

@@ -337,13 +337,13 @@ class FuyuProcessor(ProcessorMixin):
     r"""
     Constructs a Fuyu processor which wraps a Fuyu image processor and a Llama tokenizer into a single processor.

-    [`FuyuProcessor`] offers all the functionalities of [`FuyuImageProcessor`] and [`
+    [`FuyuProcessor`] offers all the functionalities of [`FuyuImageProcessor`] and [`TokenizersBackend`]. See the
     [`~FuyuProcessor.__call__`] and [`~FuyuProcessor.decode`] for more information.

     Args:
         image_processor ([`FuyuImageProcessor`]):
             The image processor is a required input.
-        tokenizer ([`
+        tokenizer ([`TokenizersBackend`]):
             The tokenizer is a required input.
     """

@@ -486,7 +486,7 @@ class FuyuProcessor(ProcessorMixin):
     ) -> "FuyuBatchFeature":
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
-        and `kwargs` arguments to
+        and `kwargs` arguments to TokenizersBackend's [`~TokenizersBackend.__call__`] if `text` is not `None` to
         encode the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
         FuyuImageProcessor's [`~FuyuImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
         of the above two methods for more information.
transformers/models/gemma/modeling_gemma.py

@@ -29,7 +29,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_func_from_hub
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import (
     GenericForSequenceClassification,
@@ -41,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_gemma import GemmaConfig


@@ -137,7 +137,7 @@ class GemmaRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -219,6 +219,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class GemmaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -244,7 +245,6 @@ class GemmaAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
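A recurring change in the Gemma-family rotary embeddings above (and again in gemma2/gemma3 below) swaps the previous autocast context for a `maybe_autocast` helper imported from `transformers.utils.generic`. The diff only shows the call site, not the helper itself; the snippet below is a hedged sketch of what such a wrapper might look like (fall back to a no-op context when autocast cannot be constructed for the device), not the actual implementation shipped in the wheel.

import contextlib

import torch


@contextlib.contextmanager
def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
    """Best-effort autocast: behave like torch.autocast, or do nothing if unsupported (assumed behavior)."""
    try:
        ctx = torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:
        ctx = contextlib.nullcontext()  # e.g. an unknown device type
    with ctx:
        yield


# Usage mirroring the call site in the diff: keep the rotary math in float32.
inv_freq = torch.randn(4, dtype=torch.float32)
with maybe_autocast(device_type="cpu", enabled=False):
    freqs = inv_freq[None, :] * torch.arange(8, dtype=torch.float32)[:, None]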
transformers/models/gemma/tokenization_gemma.py

@@ -12,12 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
 from tokenizers.models import BPE

-from ...tokenization_utils_base import generate_merges
 from ...tokenization_utils_tokenizers import TokenizersBackend
 from ...utils import logging

@@ -30,7 +29,7 @@ class GemmaTokenizer(TokenizersBackend):
     """
     Construct a fast Gemma tokenizer (backed by HuggingFace's tokenizers library).

-    This tokenizer uses a
+    This tokenizer uses a BPE model with byte fallback, no prefix space, and a normalizer that replaces
     spaces with "▁".

     Args:
@@ -50,48 +49,37 @@ class GemmaTokenizer(TokenizersBackend):
             Whether or not to add a `bos_token` at the start of sequences.
         add_eos_token (`bool`, optional, defaults to False):
             Whether or not to add an `eos_token` at the end of sequences.
-        vocab (`dict`, optional):
+        vocab (`str` or `dict[str, int]`, optional):
             Custom vocabulary dict. If not provided, a minimal vocabulary is created using the special tokens.
     """

     vocab_files_names = VOCAB_FILES_NAMES
-    slow_tokenizer_class = None
     padding_side = "left"
     model_input_names = ["input_ids", "attention_mask"]
+    model = BPE

     def __init__(
         self,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
         unk_token: str = "<unk>",
         bos_token: str = "<bos>",
         eos_token: str = "<eos>",
         pad_token: str = "<pad>",
         mask_token: str = "<mask>",
-        add_bos_token: bool = True,
-        add_eos_token: bool = False,
-        vocab: Optional[dict] = None,
-        merges: Optional[list[tuple[str, str]]] = None,
         **kwargs,
     ):
-
-
-
-        special_tokens = {str(pad_token), str(eos_token), str(bos_token), str(unk_token)}
-
-        if vocab is not None:
-            self._vocab = (
-                {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
-            )
-        else:
-            self._vocab = {
+        if vocab is None:
+            vocab = {
                 str(pad_token): 0,
                 str(eos_token): 1,
                 str(bos_token): 2,
                 str(unk_token): 3,
                 str(mask_token): 4,
             }
+        self._vocab = vocab
+        self._merges = merges or []

-        filtered_vocab = {t: i for t, i in self._vocab.items() if t not in special_tokens}
-        self._merges = merges if merges is not None else generate_merges(filtered_vocab)
         self._tokenizer = Tokenizer(
             BPE(
                 vocab=self._vocab,
@@ -108,17 +96,12 @@ class GemmaTokenizer(TokenizersBackend):
         )
         self._tokenizer.normalizer = normalizers.Replace(" ", "▁")
         self._tokenizer.pre_tokenizer = pre_tokenizers.Split(" ", "merged_with_previous")
-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             unk_token=unk_token,
             bos_token=bos_token,
             eos_token=eos_token,
             pad_token=pad_token,
             mask_token=mask_token,
-            add_bos_token=add_bos_token,
-            add_eos_token=add_eos_token,
             **kwargs,
         )

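As with Funnel, the simplified GemmaTokenizer now builds its backend directly from an explicit vocab/merges pair instead of deriving merges with `generate_merges`. Below is a rough, self-contained sketch of that kind of BPE pipeline using only the `tokenizers` library; the toy vocab and merges are invented for illustration, and the real constructor passes additional BPE options not visible in the hunk.

from tokenizers import Tokenizer, normalizers, pre_tokenizers
from tokenizers.models import BPE

# Toy vocabulary and merge rules, purely for illustration.
vocab = {"<unk>": 0, "▁": 1, "h": 2, "i": 3, "▁h": 4, "▁hi": 5}
merges = [("▁", "h"), ("▁h", "i")]

tok = Tokenizer(BPE(vocab=vocab, merges=merges, unk_token="<unk>"))
# Mirror the normalizer / pre-tokenizer set up in the diff above.
tok.normalizer = normalizers.Replace(" ", "▁")
tok.pre_tokenizer = pre_tokenizers.Split(" ", "merged_with_previous")

print(tok.encode(" hi").tokens)  # ["▁hi"] once the two merge rules have applied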
transformers/models/gemma2/modeling_gemma2.py

@@ -29,7 +29,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_func_from_hub
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -42,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_gemma2 import Gemma2Config


@@ -138,7 +138,7 @@ class Gemma2RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -229,6 +229,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Gemma2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -255,7 +256,6 @@ class Gemma2Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.attn_logit_softcapping = self.config.attn_logit_softcapping
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

transformers/models/gemma2/modular_gemma2.py

@@ -34,6 +34,7 @@ from ...modeling_rope_utils import (
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, logging
+from ...utils.generic import maybe_autocast
 from ..gemma.modeling_gemma import (
     GemmaAttention,
     GemmaForCausalLM,
@@ -252,7 +253,7 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
transformers/models/gemma3/modeling_gemma3.py

@@ -31,16 +31,15 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_func_from_hub
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from ..auto import AutoModel
 from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig

@@ -215,7 +214,7 @@ class Gemma3RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * attention_scaling
@@ -306,6 +305,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Gemma3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -332,7 +332,6 @@ class Gemma3Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.attn_logit_softcapping = self.config.attn_logit_softcapping
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
         self.is_sliding = self.layer_type == "sliding_attention"
@@ -347,7 +346,7 @@ class Gemma3Attention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -409,23 +408,19 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states

         hidden_states = self.input_layernorm(hidden_states)

-        hidden_states,
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
@@ -438,12 +433,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states

-
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return hidden_states


 @auto_docstring
@@ -527,30 +517,16 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if use_cache and past_key_values is None
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
@@ -591,41 +567,22 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_embeddings=position_embeddings[decoder_layer.attention_type],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)

-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )


@@ -918,10 +875,7 @@ class Gemma3Model(Gemma3PreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Gemma3ModelOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -953,12 +907,6 @@ class Gemma3Model(Gemma3PreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # Replace image id with PAD if the image token if OOV, to avoid index-errors
         if input_ids is not None and self.config.image_token_id >= self.vocab_size:
             special_image_mask = input_ids == self.config.image_token_id
@@ -1005,8 +953,6 @@ class Gemma3Model(Gemma3PreTrainedModel):
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             return_dict=True,
             cache_position=cache_position,
             **lm_kwargs,
@@ -1014,7 +960,7 @@ class Gemma3Model(Gemma3PreTrainedModel):

         return Gemma3ModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
-            past_key_values=outputs.past_key_values
+            past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
             image_hidden_states=image_features if pixel_values is not None else None,
@@ -1053,6 +999,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
     def get_image_features(self, pixel_values):
         return self.model.get_image_features(pixel_values)

+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -1066,11 +1013,8 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1116,13 +1060,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
         ```
         """
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         outputs = self.model(
             input_ids=input_ids,
             pixel_values=pixel_values,
@@ -1133,9 +1070,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             labels=labels,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             cache_position=cache_position,
             **lm_kwargs,
         )
@@ -1167,10 +1101,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             flat_labels = shift_labels.view(-1).to(shift_logits.device)
             loss = loss_fct(flat_logits, flat_labels)

-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
         return Gemma3CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
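Most of the Gemma3 deletions above are the manual `output_attentions` / `output_hidden_states` / `return_dict` plumbing, which the forward signatures now replace with a single typed `**kwargs: Unpack[TransformersKwargs]` plus decorators such as `@can_return_tuple`. Below is a generic sketch of that typing pattern; the `ForwardKwargs` TypedDict is a stand-in invented for illustration, not the `TransformersKwargs` definition from the library.

from typing import TypedDict

import torch
from typing_extensions import Unpack


class ForwardKwargs(TypedDict, total=False):
    # Hypothetical stand-in for TransformersKwargs: optional flags callers may still pass.
    output_attentions: bool
    output_hidden_states: bool


def forward(hidden_states: torch.Tensor, **kwargs: Unpack[ForwardKwargs]) -> torch.Tensor:
    # The layer no longer threads the flags through explicitly; anything passed lands in kwargs,
    # and static type checkers can still verify the allowed keys and their types.
    if kwargs.get("output_attentions", False):
        print("caller asked for attention weights")
    return hidden_states


forward(torch.zeros(1, 4), output_attentions=True)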