transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/nllb_moe/modeling_nllb_moe.py

@@ -21,6 +21,7 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -66,6 +67,7 @@ class NllbMoeSinusoidalPositionalEmbedding(nn.Module):
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
         super().__init__()
         self.offset = 2
+        self.num_positions = num_positions
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
@@ -665,6 +667,14 @@ class NllbMoePreTrainedModel(PreTrainedModel):
     _supports_sdpa = False
     _supports_flex_attn = False
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, NllbMoeSinusoidalPositionalEmbedding):
+            emb_weights = module.get_embedding(
+                module.num_positions + module.offset, module.embedding_dim, module.padding_idx
+            )
+            init.copy_(module.weights, emb_weights)
+
 
 class NllbMoeEncoder(NllbMoePreTrainedModel):
     _can_record_outputs = {
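This NLLB-MoE change illustrates a pattern repeated across this release (note the new transformers/initialization.py, +37 lines, in the file list above): buffer contents that used to be filled eagerly in __init__ are instead re-materialized inside _init_weights, presumably so initialization can be replayed after meta-device or device_map loading. A minimal sketch of the idea, assuming init.copy_(buffer, value) behaves like an in-place Tensor.copy_; the names below are illustrative, not the actual NLLB-MoE implementation:

import math

import torch
from torch import nn


class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        self.offset = 2  # ids 0/1 are reserved, as in the hunk above
        self.num_positions = num_positions  # stored so _init_weights can rebuild the table
        self.embedding_dim = embedding_dim
        self.register_buffer("weights", torch.empty(num_positions + self.offset, embedding_dim), persistent=False)

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int) -> torch.Tensor:
        # fairseq-style sin/cos table; assumes an even embedding_dim for brevity
        half = embedding_dim // 2
        freqs = torch.exp(torch.arange(half, dtype=torch.float32) * (-math.log(10000.0) / (half - 1)))
        angles = torch.arange(num_embeddings, dtype=torch.float32).unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)


def init_weights(module: nn.Module) -> None:
    # analogue of the new NllbMoePreTrainedModel._init_weights; a plain in-place
    # copy_ stands in for init.copy_ from transformers.initialization
    if isinstance(module, SinusoidalPositionalEmbedding):
        table = module.get_embedding(module.num_positions + module.offset, module.embedding_dim)
        with torch.no_grad():
            module.weights.copy_(table)


emb = SinusoidalPositionalEmbedding(num_positions=8, embedding_dim=4)
init_weights(emb)
assert emb.weights.shape == (10, 4)

Storing num_positions on the module (the one-line __init__ change above) is what makes the table reconstructible from _init_weights alone.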
transformers/models/nougat/image_processing_nougat_fast.py

@@ -290,7 +290,6 @@ class NougatImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
@@ -380,16 +380,16 @@ class NougatTokenizer(TokenizersBackend):
|
|
|
380
380
|
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
|
381
381
|
The token used for padding, for example when batching sequences of different lengths.
|
|
382
382
|
|
|
383
|
-
vocab (`dict`, *optional*):
|
|
383
|
+
vocab (`str`, `dict` or `list`, *optional*):
|
|
384
384
|
Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
|
|
385
385
|
|
|
386
|
-
merges (`list`, *optional*):
|
|
386
|
+
merges (`str` or `list`, *optional*):
|
|
387
387
|
Custom merges list. If not provided, merges are loaded from merges_file.
|
|
388
388
|
"""
|
|
389
389
|
|
|
390
390
|
vocab_files_names = VOCAB_FILES_NAMES
|
|
391
391
|
model_input_names = ["input_ids", "attention_mask"]
|
|
392
|
-
|
|
392
|
+
model = BPE
|
|
393
393
|
|
|
394
394
|
def __init__(
|
|
395
395
|
self,
|
|
@@ -398,28 +398,22 @@ class NougatTokenizer(TokenizersBackend):
|
|
|
398
398
|
bos_token: str = "<s>",
|
|
399
399
|
eos_token: str = "</s>",
|
|
400
400
|
pad_token: str = "<pad>",
|
|
401
|
-
vocab: Optional[dict] = None,
|
|
402
|
-
merges: Optional[list] = None,
|
|
401
|
+
vocab: Optional[Union[str, dict, list]] = None,
|
|
402
|
+
merges: Optional[Union[str, list]] = None,
|
|
403
403
|
**kwargs,
|
|
404
404
|
):
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
else:
|
|
410
|
-
self._vocab = {
|
|
405
|
+
self._vocab = (
|
|
406
|
+
vocab
|
|
407
|
+
if vocab is not None
|
|
408
|
+
else {
|
|
411
409
|
str(bos_token): 0,
|
|
412
410
|
str(pad_token): 1,
|
|
413
411
|
str(eos_token): 2,
|
|
414
412
|
str(unk_token): 3,
|
|
415
413
|
"[START_REF]": 4,
|
|
416
414
|
}
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
self._merges = merges
|
|
420
|
-
else:
|
|
421
|
-
self._merges = []
|
|
422
|
-
|
|
415
|
+
)
|
|
416
|
+
self._merges = merges or []
|
|
423
417
|
self._tokenizer = Tokenizer(
|
|
424
418
|
BPE(
|
|
425
419
|
vocab=self._vocab,
|
|
@@ -447,27 +441,7 @@ class NougatTokenizer(TokenizersBackend):
|
|
|
447
441
|
)
|
|
448
442
|
self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True)
|
|
449
443
|
|
|
450
|
-
# Set up post processor with bos and eos tokens
|
|
451
|
-
bos_token_id = self._vocab.get(str(bos_token), 0)
|
|
452
|
-
eos_token_id = self._vocab.get(str(eos_token), 2)
|
|
453
|
-
pad_token_id = self._vocab.get(str(pad_token), 1)
|
|
454
|
-
self._tokenizer.post_processor = processors.TemplateProcessing(
|
|
455
|
-
single=f"{bos_token}:0 $A:0 {eos_token}:0",
|
|
456
|
-
pair="$A:0 $B:1",
|
|
457
|
-
special_tokens=[
|
|
458
|
-
(str(eos_token), eos_token_id),
|
|
459
|
-
(str(bos_token), bos_token_id),
|
|
460
|
-
],
|
|
461
|
-
)
|
|
462
|
-
|
|
463
|
-
# Enable truncation and padding
|
|
464
|
-
self._tokenizer.enable_truncation(max_length=4096)
|
|
465
|
-
self._tokenizer.enable_padding(length=4096, pad_id=pad_token_id, pad_token=str(pad_token))
|
|
466
|
-
|
|
467
|
-
tokenizer_object = self._tokenizer
|
|
468
|
-
|
|
469
444
|
super().__init__(
|
|
470
|
-
tokenizer_object=tokenizer_object,
|
|
471
445
|
errors=errors,
|
|
472
446
|
unk_token=unk_token,
|
|
473
447
|
bos_token=bos_token,
|
|
@@ -475,45 +449,18 @@ class NougatTokenizer(TokenizersBackend):
             pad_token=pad_token,
             **kwargs,
         )
-
-    def _post_init(self):
-        """Post-initialization to ensure tokenizer settings are applied correctly."""
-        # Re-apply settings to ensure they're correct after loading from pretrained
-        self._tokenizer.normalizer = normalizers.NFKC()
-        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
-            [
-                pre_tokenizers.Split(pattern="SPL1T-TH1S-Pl3A5E", behavior="removed", invert=False),
-                pre_tokenizers.Digits(individual_digits=True),
-                pre_tokenizers.Split(
-                    pattern=r"[\(\)\[\]\{\}]|([!\"#\$%\&'\*\+,\-\./:;<=>\?\\\^_`\|\~])\1*",
-                    behavior="isolated",
-                    invert=False,
-                ),
-                pre_tokenizers.Split(pattern="\n", behavior="isolated", invert=False),
-                pre_tokenizers.ByteLevel(add_prefix_space=False, trim_offsets=True, use_regex=True),
-            ]
-        )
-        self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True)
-
-        # Set up post processor with bos and eos tokens
-        bos_token_id = self.bos_token_id if self.bos_token_id is not None else 0
-        eos_token_id = self.eos_token_id if self.eos_token_id is not None else 2
-        pad_token_id = self.pad_token_id if self.pad_token_id is not None else 1
         self._tokenizer.post_processor = processors.TemplateProcessing(
-            single=f"{self.bos_token}:0 $A:0 {self.eos_token}:0",
+            single=f"{bos_token}:0 $A:0 {eos_token}:0",
             pair="$A:0 $B:1",
             special_tokens=[
-                (str(self.eos_token), eos_token_id),
-                (str(self.bos_token), bos_token_id),
+                (str(eos_token), self.eos_token_id),
+                (str(bos_token), self.bos_token_id),
             ],
         )

         # Enable truncation and padding
         self._tokenizer.enable_truncation(max_length=4096)
-        self._tokenizer.enable_padding(length=4096, pad_id=pad_token_id, pad_token=str(self.pad_token))
-
-        # Call parent to handle AddedToken properties
-        super()._post_init()
+        self._tokenizer.enable_padding(length=4096, pad_id=self.pad_token_id, pad_token=str(pad_token))

     def remove_hallucinated_references(self, text: str) -> str:
         """
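With `_post_init` removed, the post-processor, truncation, and padding are configured exactly once, right after `super().__init__()`, using the ids the backend resolved. A standalone sketch of that wiring, assuming the `tokenizers` package and the fallback ids above:

from tokenizers import Tokenizer, processors
from tokenizers.models import BPE

vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "[START_REF]": 4, "a": 5}
tok = Tokenizer(BPE(vocab=vocab, merges=[], unk_token="<unk>"))
tok.post_processor = processors.TemplateProcessing(
    single="<s>:0 $A:0 </s>:0",  # wrap every single sequence in bos/eos
    pair="$A:0 $B:1",
    special_tokens=[("</s>", 2), ("<s>", 0)],
)
tok.enable_truncation(max_length=4096)
tok.enable_padding(length=4096, pad_id=1, pad_token="<pad>")
print(tok.encode("a").ids[:3])  # [0, 5, 2]; padding fills the remaining slots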
@@ -21,6 +21,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
@@ -413,6 +414,12 @@ class NystromformerPreTrainedModel(PreTrainedModel):
     base_model_prefix = "nystromformer"
     supports_gradient_checkpointing = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, NystromformerEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)) + 2)
+            init.zeros_(module.token_type_ids)
+

 @auto_docstring
 class NystromformerModel(NystromformerPreTrainedModel):
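The new `_init_weights` defers standard modules to the base class, then deterministically re-seeds the embedding buffers (position ids offset by 2, token type ids zeroed). A condensed sketch of the buffer-init pattern with a stand-in module (the names here are illustrative, not the library's):

import torch
from torch import nn

class ToyEmbeddings(nn.Module):
    def __init__(self, max_positions: int = 8):
        super().__init__()
        self.register_buffer("position_ids", torch.zeros(1, max_positions, dtype=torch.long))
        self.register_buffer("token_type_ids", torch.ones(1, max_positions, dtype=torch.long))

def init_weights(module: nn.Module) -> None:
    # Mirrors the diff: position ids start at 2, token type ids reset to zero.
    if isinstance(module, ToyEmbeddings):
        with torch.no_grad():
            module.position_ids.copy_(torch.arange(module.position_ids.shape[-1]).expand((1, -1)) + 2)
            module.token_type_ids.zero_()

emb = ToyEmbeddings()
init_weights(emb)
print(emb.position_ids)  # tensor([[2, 3, 4, 5, 6, 7, 8, 9]])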
@@ -443,6 +450,7 @@ class NystromformerModel(NystromformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -539,6 +547,7 @@ class NystromformerForMaskedLM(NystromformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -628,6 +637,7 @@ class NystromformerForSequenceClassification(NystromformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -709,6 +719,7 @@ class NystromformerForMultipleChoice(NystromformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -814,6 +825,7 @@ class NystromformerForTokenClassification(NystromformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -881,6 +893,7 @@ class NystromformerForQuestionAnswering(NystromformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

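Each of these six hunks only appends `**kwargs,` to a `forward` signature, so the Nystromformer heads tolerate extra keyword arguments that newer call sites thread through. In miniature:

def forward_old(input_ids, return_dict=None):
    return input_ids

def forward_new(input_ids, return_dict=None, **kwargs):
    return input_ids  # unknown keyword arguments are simply ignored

forward_new(1, return_dict=True, some_new_flag=False)   # fine
# forward_old(1, return_dict=True, some_new_flag=False) # would raise TypeError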
@@ -34,6 +34,7 @@ import torch.nn.functional as F
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -41,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_olmo import OlmoConfig

@@ -92,7 +93,7 @@ class OlmoRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
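Registering `original_inv_freq` as a non-persistent buffer (instead of the old plain-attribute alias) means the pristine copy follows `.to()`/`.cuda()` device moves yet never enters the `state_dict`, and the `clone()` detaches it from later in-place rescaling of `inv_freq`. For instance:

import torch
from torch import nn

class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.rand(4)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m = Rope()
m.inv_freq *= 2.0  # e.g. dynamic-RoPE rescaling
print(torch.equal(m.inv_freq, m.original_inv_freq))  # False: the clone is untouched
print(list(m.state_dict().keys()))  # []: non-persistent buffers are not serialized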
@@ -131,7 +132,7 @@ class OlmoRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
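`maybe_autocast` replaces a bare autocast context here; the intent is unchanged: keep the inverse-frequency product in float32 even when the surrounding forward pass runs under mixed precision. A standalone sketch with plain `torch.autocast`:

import torch

inv_freq = torch.rand(1, 4, 1)                         # (bs, head_dim // 2, 1)
position_ids = torch.arange(6).float().view(1, 1, 6)   # (bs, 1, seq_len)

with torch.autocast(device_type="cpu", enabled=False):  # force float32
    freqs = (inv_freq.float() @ position_ids.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos(), emb.sin()

print(freqs.shape, emb.dtype)  # torch.Size([1, 6, 4]) torch.float32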
@@ -212,6 +213,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed.to(q_type), k_embed.to(k_type)


+@use_kernelized_func(apply_rotary_pos_emb)
 class OlmoAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

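The `@use_kernelized_func(apply_rotary_pos_emb)` class decorator supersedes the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment that the next hunks delete. A hedged sketch of what such a decorator can look like; the real `use_kernelized_func` lives in `transformers.integrations` and may additionally swap in a hub kernel:

def use_kernelized_func(func):
    """Attach `func` to the class so instances call it as `self.rotary_fn`.

    A kernel backend could replace `func` with an optimized implementation
    here before attaching it (this sketch just attaches it unchanged).
    """
    def decorator(cls):
        cls.rotary_fn = staticmethod(func)
        return cls
    return decorator

def apply_rotary_pos_emb(q, k, cos, sin):
    return q, k  # placeholder body

@use_kernelized_func(apply_rotary_pos_emb)
class Attention:
    pass

print(Attention().rotary_fn("q", "k", None, None))  # ('q', 'k')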
@@ -237,7 +239,6 @@ class OlmoAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
@@ -246,7 +247,6 @@ class OlmoAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_shape = hidden_states.shape[:-1]
@@ -29,6 +29,7 @@ from ...cache_utils import Cache
 from ...modeling_rope_utils import dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...utils import logging
+from ...utils.generic import maybe_autocast
 from ..llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,
@@ -77,7 +78,7 @@ class OlmoRotaryEmbedding(LlamaRotaryEmbedding):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -121,7 +122,6 @@ class OlmoAttention(LlamaAttention):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_shape = hidden_states.shape[:-1]
@@ -35,7 +35,7 @@ from transformers.utils.generic import TransformersKwargs
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -43,7 +43,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_olmo2 import Olmo2Config

@@ -85,7 +85,7 @@ class Olmo2RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -124,7 +124,7 @@ class Olmo2RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -205,6 +205,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+@use_kernelized_func(apply_rotary_pos_emb)
 class Olmo2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -230,7 +231,6 @@ class Olmo2Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
         self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)

@@ -241,7 +241,6 @@ class Olmo2Attention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_shape = hidden_states.shape[:-1]
@@ -219,7 +219,6 @@ class Olmo2Attention(OlmoAttention):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_shape = hidden_states.shape[:-1]
@@ -30,7 +30,7 @@ from transformers.utils.generic import TransformersKwargs
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_olmo3 import Olmo3Config

@@ -136,6 +136,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+@use_kernelized_func(apply_rotary_pos_emb)
 class Olmo3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -161,7 +162,6 @@ class Olmo3Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Olmo3RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
         self.k_norm = Olmo3RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
         assert config.layer_types is not None
@@ -293,7 +293,7 @@ class Olmo3RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -332,7 +332,7 @@ class Olmo3RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -27,15 +27,20 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_olmoe import OlmoeConfig

@@ -77,7 +82,7 @@ class OlmoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -116,7 +121,7 @@ class OlmoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -214,6 +219,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class OlmoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -239,7 +245,6 @@ class OlmoeAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.k_norm = OlmoeRMSNorm(
             (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps
@@ -298,6 +303,7 @@ class OlmoeAttention(nn.Module):
         return attn_output, attn_weights


+@use_experts_implementation
 class OlmoeExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

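`OlmoeExperts` keeps all expert weights stacked as 3D tensors, and `@use_experts_implementation` evidently lets that dense layout be dispatched to different backends (a naive per-expert loop versus grouped matmul). A simplified top-1 routing sketch over stacked weights; the real block uses top-k routing and a gated activation, so this is only the shape of the idea:

import torch

num_experts, hidden, inter = 4, 8, 16
tokens = torch.rand(5, hidden)
up_w = torch.rand(num_experts, hidden, inter)    # experts stacked as one 3D tensor
down_w = torch.rand(num_experts, inter, hidden)
expert_idx = torch.randint(num_experts, (5,))    # one routed expert per token

# Naive gather-then-matmul reference; a grouped-GEMM backend batches this instead.
out = torch.empty_like(tokens)
for e in range(num_experts):
    mask = expert_idx == e
    if mask.any():
        h = torch.relu(tokens[mask] @ up_w[e])
        out[mask] = h @ down_w[e]
print(out.shape)  # torch.Size([5, 8])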
@@ -431,7 +437,9 @@ class OlmoePreTrainedModel(PreTrainedModel):
         "hidden_states": OlmoeDecoderLayer,
         "attentions": OlmoeAttention,
     }
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True

     @torch.no_grad()
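`_can_compile_fullgraph` is now computed at class-definition time from `is_grouped_mm_available()` rather than hard-coded, so full-graph `torch.compile` support is advertised only where grouped matrix multiplication exists. The pattern, with a hypothetical probe standing in for the transformers helper:

import torch

def grouped_mm_available() -> bool:
    # Hypothetical stand-in for transformers' is_grouped_mm_available().
    return hasattr(torch, "_grouped_mm")

class MoEModel:
    # Evaluated once when the class is created; instances just read it.
    can_compile_fullgraph = grouped_mm_available()

print(MoEModel.can_compile_fullgraph)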
@@ -24,7 +24,7 @@ from ...masking_utils import create_causal_mask
 from ...modeling_outputs import MoeModelOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, logging
+from ...utils import TransformersKwargs, auto_docstring, is_grouped_mm_available, logging
 from ...utils.generic import OutputRecorder
 from ..gemma.modeling_gemma import GemmaMLP
 from ..llama.modeling_llama import (
@@ -165,7 +165,9 @@ class OlmoePreTrainedModel(PreTrainedModel):
         "hidden_states": OlmoeDecoderLayer,
         "attentions": OlmoeAttention,
     }
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True

     @torch.no_grad()
@@ -36,7 +36,7 @@ class OmDetTurboConfig(PreTrainedConfig):
     Args:
         text_config (`PreTrainedConfig`, *optional*):
             The configuration of the text backbone.
-        backbone_config (`PreTrainedConfig`, *optional*):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `SwinConfig()`):
             The configuration of the vision backbone.
         use_timm_backbone (`bool`, *optional*, defaults to `True`):
             Whether to use the timm for the vision backbone.
@@ -68,7 +68,7 @@ class OmDetTurboConfig(PreTrainedConfig):
         class_embed_dim (`int`, *optional*, defaults to 512):
             The dimension of the classes embeddings.
         class_distance_type (`str`, *optional*, defaults to `"cosine"`):
-            The type of
+            The type of distance to compare predicted classes to projected classes embeddings.
             Can be `"cosine"` or `"dot"`.
         num_queries (`int`, *optional*, defaults to 900):
             The number of queries.
@@ -1022,6 +1022,10 @@ class OmDetTurboPreTrainedModel(PreTrainedModel):
         elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
             init.ones_(module.weight)
             init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)

     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, OmDetTurboDecoder):
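The added branch re-initializes BatchNorm's running statistics, not just its affine parameters; `running_mean`, `running_var`, and `num_batches_tracked` are buffers, so parameter-only init would leave stale values behind. The equivalent in stock PyTorch:

import torch
from torch import nn

bn = nn.BatchNorm2d(3)
bn(torch.rand(2, 3, 4, 4))              # one training-mode forward updates the stats
print(bn.running_mean.abs().sum() > 0)  # tensor(True)

# What the new init branch restores by hand:
bn.running_mean.zero_()
bn.running_var.fill_(1.0)
bn.num_batches_tracked.zero_()
# nn.BatchNorm2d also ships a helper with the same effect: bn.reset_running_stats()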
@@ -1316,6 +1320,7 @@ class OmDetTurboDecoder(OmDetTurboPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         """
         Args:
@@ -1505,6 +1510,7 @@ class OmDetTurboForObjectDetection(OmDetTurboPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], OmDetTurboObjectDetectionOutput]:
         r"""
         classes_input_ids (`torch.LongTensor` of shape `(total_classes (>= batch_size), sequence_length)`):
@@ -14,7 +14,7 @@
 # limitations under the License.
 """OneFormer model configuration"""

-from typing import Optional
+from typing import Optional, Union

 from ...configuration_utils import PreTrainedConfig
 from ...utils import logging
@@ -37,7 +37,7 @@ class OneFormerConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.

     Args:
-        backbone_config (`PreTrainedConfig`, *optional*, defaults to `SwinConfig`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `SwinConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -151,7 +151,7 @@ class OneFormerConfig(PreTrainedConfig):

     def __init__(
         self,
-        backbone_config: Optional[dict] = None,
+        backbone_config: Optional[Union[dict, PreTrainedConfig]] = None,
         backbone: Optional[str] = None,
         use_pretrained_backbone: bool = False,
         use_timm_backbone: bool = False,
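Widening `backbone_config` to `Union[dict, PreTrainedConfig]` matches how composite configs are typically normalized: a plain dict (say, deserialized from JSON) is promoted to the sub-config class, a ready config object passes through, and `None` yields the documented default. A schematic with simplified stand-ins (not the library code):

from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class SwinLikeConfig:  # stand-in for SwinConfig
    depths: tuple = (2, 2, 6, 2)

def resolve_backbone_config(backbone_config: Optional[Union[dict, SwinLikeConfig]]) -> SwinLikeConfig:
    if backbone_config is None:
        return SwinLikeConfig()              # the documented default
    if isinstance(backbone_config, dict):
        return SwinLikeConfig(**backbone_config)
    return backbone_config                   # already a config object

print(resolve_backbone_config({"depths": (2, 2)}))
print(resolve_backbone_config(None))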