transformers 5.0.0rc0-py3-none-any.whl → 5.0.0rc2-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/jamba/modeling_jamba.py

```diff
@@ -32,7 +32,13 @@ from torch import nn
 from ... import initialization as init
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
-from ...integrations import
+from ...integrations import (
+    lazy_load_kernel,
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -40,22 +46,9 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.generic import OutputRecorder, check_model_inputs
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
 from .configuration_jamba import JambaConfig


-if is_mamba_ssm_available():
-    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-else:
-    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-
 logger = logging.get_logger(__name__)


@@ -248,6 +241,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class JambaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -264,7 +258,6 @@ class JambaAttention(nn.Module):
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
```
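These two hunks trade the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment for the new `use_kernelized_func` class decorator (exported by `transformers.integrations` in the expanded import above). The decorator's internals are not shown in this diff; below is a hypothetical sketch of the pattern, where the `rotary_fn` attribute name is taken from the removed line but the registry hook is an assumption:

```python
# Hypothetical sketch of a kernelized-function class decorator.
# Assumption: it pins the reference Python implementation on the class
# so a faster hub kernel can later be rebound under the same name.
def use_kernelized_func(reference_fn):
    def decorator(cls):
        # Default to the pure-PyTorch reference; a kernel registry could
        # overwrite `cls.rotary_fn` once a hub kernel is loaded.
        cls.rotary_fn = staticmethod(reference_fn)
        return cls
    return decorator
```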
transformers/models/jamba/modeling_jamba.py (continued)

```diff
@@ -306,11 +299,6 @@ class JambaAttention(nn.Module):
         return attn_output, attn_weights


-is_fast_path_available = all(
-    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
-)
-
-
 class JambaMambaMixer(nn.Module):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
@@ -364,6 +352,22 @@ class JambaMambaMixer(nn.Module):
         self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
         self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)

+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_inner_fn, selective_scan_fn
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)
+        selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all(
+            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+        )
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
```
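With this hunk (and its twin in `modular_jamba.py` below), availability of the optional `mamba-ssm` and `causal-conv1d` kernels is decided when the first mixer layer is constructed rather than at import time. A minimal usage sketch, assuming `lazy_load_kernel` returns the loaded kernel module, or `None` when the kernel cannot be loaded (the `getattr(..., None)` guards above imply that contract):

```python
from transformers.integrations import lazy_load_kernel

# Resolve the optional kernel on demand instead of at import time.
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

if causal_conv1d_fn is None:
    # The same condition under which Jamba warns and takes the slow path.
    print("causal-conv1d unavailable; the model will use the slow path")
```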
@@ -594,6 +598,7 @@ class JambaMLP(nn.Module):
|
|
|
594
598
|
return down_proj
|
|
595
599
|
|
|
596
600
|
|
|
601
|
+
@use_experts_implementation
|
|
597
602
|
class JambaExperts(nn.Module):
|
|
598
603
|
"""Collection of expert weights stored as 3D tensors."""
|
|
599
604
|
|
|
@@ -25,6 +25,7 @@ from torch import nn
|
|
|
25
25
|
|
|
26
26
|
from ... import initialization as init
|
|
27
27
|
from ...activations import ACT2FN
|
|
28
|
+
from ...integrations import lazy_load_kernel
|
|
28
29
|
from ...masking_utils import create_causal_mask
|
|
29
30
|
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
|
|
30
31
|
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
|
@@ -32,29 +33,12 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, logging
 from ...utils.generic import OutputRecorder, check_model_inputs
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
 from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm, eager_attention_forward
 from ..mistral.modeling_mistral import MistralMLP
 from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM
 from .configuration_jamba import JambaConfig


-if is_mamba_ssm_available():
-    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-else:
-    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-is_fast_path_available = all(
-    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
-)
-
-
 logger = logging.get_logger(__name__)

@@ -258,6 +242,22 @@ class JambaMambaMixer(nn.Module):
         self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
         self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)

+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_inner_fn, selective_scan_fn
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)
+        selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all(
+            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+        )
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
@@ -180,7 +180,6 @@ class JanusImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

@@ -28,6 +28,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList
@@ -58,6 +59,11 @@ class JanusPreTrainedModel(PreTrainedModel):

     _can_compile_fullgraph = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, JanusVisionEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+

 @dataclass
 @auto_docstring(
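The `_init_weights` overrides added throughout this diff rebuild non-persistent buffers such as `position_ids` rather than expecting them in the checkpoint. A rough plain-torch equivalent of what the `init.copy_` call is used for here, assuming it performs a guarded in-place copy (toy `VisionEmbeddings`, not the actual Janus module):

    import torch
    from torch import nn

    class VisionEmbeddings(nn.Module):
        def __init__(self, num_positions: int):
            super().__init__()
            # non-persistent: absent from checkpoints, so it must be rebuilt at init time
            self.register_buffer("position_ids", torch.empty(1, num_positions, dtype=torch.long), persistent=False)

    def init_position_ids(module: nn.Module) -> None:
        if isinstance(module, VisionEmbeddings):
            with torch.no_grad():
                module.position_ids.copy_(torch.arange(module.position_ids.shape[-1]).expand((1, -1)))

    init_position_ids(VisionEmbeddings(16))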
@@ -973,8 +979,6 @@ class JanusVQVAE(JanusPreTrainedModel):
         self.eval()  # Janus's VQ model is frozen
         self.decoder = JanusVQVAEDecoder(config)
         self.gradient_checkpointing = False
-
-        # Initialize the VQVAE model.
         self.post_init()

     def encode(self, pixel_values: torch.LongTensor):
@@ -1007,6 +1011,7 @@ class JanusVQVAE(JanusPreTrainedModel):
     def forward(
         self,
         pixel_values: torch.FloatTensor,
+        **kwargs,
     ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
         batch_size = pixel_values.shape[0]
         quant, embedding_loss, indices = self.encode(pixel_values)
@@ -1125,7 +1130,7 @@ class JanusModel(JanusPreTrainedModel):
         use_cache: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs,
-    ):
+    ) -> JanusBaseModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
@@ -1202,7 +1207,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs: Unpack[TransformersKwargs],
-    ):
+    ) -> JanusCausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1249,6 +1254,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
         inputs_embeds=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- extra custom processing
@@ -1260,12 +1266,15 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-        # Otherwise we need pixel values to be passed to model
-        if cache_position[0] == 0:
+        # Pixel values are used only in the first iteration if available
+        # In subsequent iterations, they are already merged with text and cached
+        # NOTE: the first iteration doesn't have to be prefill; it can be the first
+        # iteration with a question and a cached system prompt (continuing generation from the cache)
+        if is_first_iteration or not kwargs.get("use_cache", True):
             model_inputs["pixel_values"] = pixel_values

         return model_inputs
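The `is_first_iteration` flag threaded through `prepare_inputs_for_generation` above replaces cache-position probing with an explicit signal from the generation loop. A self-contained sketch of the gating rule (a standalone toy function, not the model's method):

    def prepare_inputs(pixel_values, is_first_iteration=False, use_cache=True, **kwargs):
        model_inputs = dict(kwargs)
        # forward image inputs only while nothing about them is cached yet
        if is_first_iteration or not use_cache:
            model_inputs["pixel_values"] = pixel_values
        return model_inputs

    assert "pixel_values" in prepare_inputs(object(), is_first_iteration=True)
    assert "pixel_values" not in prepare_inputs(object())  # cached decoding step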
@@ -24,8 +24,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn

-from
-
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...configuration_utils import PreTrainedConfig
@@ -58,6 +57,7 @@ from ...utils import (
     logging,
 )
 from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
+from ..blip.image_processing_blip import BlipImageProcessor
 from ..blip_2.modeling_blip_2 import Blip2VisionModel
 from ..chameleon.configuration_chameleon import ChameleonVQVAEConfig
 from ..chameleon.modeling_chameleon import (
@@ -391,6 +391,11 @@ class JanusPreTrainedModel(PreTrainedModel):

     _can_compile_fullgraph = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, JanusVisionEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+

 @dataclass
 @auto_docstring(
@@ -823,6 +828,7 @@ class JanusVQVAE(ChameleonVQVAE):
     def forward(
         self,
         pixel_values: torch.FloatTensor,
+        **kwargs,
     ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
         batch_size = pixel_values.shape[0]
         quant, embedding_loss, indices = self.encode(pixel_values)
@@ -941,7 +947,7 @@ class JanusModel(JanusPreTrainedModel):
         use_cache: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs,
-    ):
+    ) -> JanusBaseModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
@@ -1018,7 +1024,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs: Unpack[TransformersKwargs],
-    ):
+    ) -> JanusCausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1065,6 +1071,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
         inputs_embeds=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- extra custom processing
@@ -1076,12 +1083,15 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-        # Otherwise we need pixel values to be passed to model
-        if cache_position[0] == 0:
+        # Pixel values are used only in the first iteration if available
+        # In subsequent iterations, they are already merged with text and cached
+        # NOTE: the first iteration doesn't have to be prefill; it can be the first
+        # iteration with a question and a cached system prompt (continuing generation from the cache)
+        if is_first_iteration or not kwargs.get("use_cache", True):
             model_inputs["pixel_values"] = pixel_values

         return model_inputs
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_jetmoe import JetMoeConfig

@@ -83,7 +83,7 @@ class JetMoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
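Registering `original_inv_freq` as a non-persistent buffer instead of a plain attribute means the saved copy follows `module.to(...)` device and dtype moves while staying out of the state dict; the clone presumably exists so dynamic-RoPE updates to `inv_freq` can later be rolled back. A small sketch of both properties:

    import torch
    from torch import nn

    class Rotary(nn.Module):
        def __init__(self, inv_freq: torch.Tensor):
            super().__init__()
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    rot = Rotary(1.0 / torch.arange(1, 5, dtype=torch.float))
    rot.to(torch.float64)                      # buffers follow module-level moves
    assert "inv_freq" not in rot.state_dict()  # non-persistent buffers are not checkpointed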
@@ -122,7 +122,7 @@ class JetMoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
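`maybe_autocast` replaces a direct `torch.autocast` context here; the diff only shows its call shape. A hedged guess at what such a wrapper needs to do, falling back to a no-op context when autocast rejects the device type (the real helper in `...utils.generic` may behave differently):

    import contextlib
    import torch

    def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
        try:
            return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
        except RuntimeError:
            # device types without autocast support get a no-op context
            return contextlib.nullcontext()

    with maybe_autocast(device_type="cpu", enabled=False):  # force float32 math
        _ = torch.ones(2) @ torch.ones(2)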
@@ -576,7 +576,7 @@ class JetMoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False  #
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(nn.Linear, layer_name="gate", index=1),
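The new comment pins the `fullgraph` failure to `expert_size = expert_size.tolist()`: `.tolist()` pulls data-dependent values out of tensor land, which requires a graph break, and `torch.compile(fullgraph=True)` turns any graph break into an error. A toy illustration (not JetMoe's actual router):

    import torch

    @torch.compile(fullgraph=True)
    def gate(router_logits: torch.Tensor):
        counts = torch.bincount(router_logits.argmax(-1))
        return counts.tolist()  # data-dependent host transfer -> graph break

    # gate(torch.randn(8, 4))  # raises at call time under fullgraph=True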
@@ -435,6 +435,7 @@ class JetMoePreTrainedModel(MixtralPreTrainedModel):
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"

     @torch.no_grad()
     def _init_weights(self, module):
@@ -559,6 +559,7 @@ class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module):
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
         super().__init__()
         self.offset = 2
+        self.num_positions = num_positions
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
@@ -1138,6 +1139,7 @@ class Kosmos2PreTrainedModel(PreTrainedModel):
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
             init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, Kosmos2VisionAttention):
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
             out_proj_std = (module.embed_dim**-0.5) * factor
@@ -1170,6 +1172,11 @@ class Kosmos2PreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             init.ones_(module.weight)
             init.zeros_(module.bias)
+        elif isinstance(module, Kosmos2TextSinusoidalPositionalEmbedding):
+            emb_weights = module.get_embedding(
+                module.num_positions + module.offset, module.embedding_dim, module.padding_idx
+            )
+            init.copy_(module.weights, emb_weights)

         if isinstance(module, nn.Linear) and module.bias is not None:
             init.zeros_(module.bias)
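The new `Kosmos2TextSinusoidalPositionalEmbedding` branch recomputes the sinusoidal table during weight init, which is also why `num_positions` is now stored on the module. For reference, a fairseq-style `get_embedding` sketch (the model's own implementation may differ in detail):

    import math
    import torch

    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx=None) -> torch.Tensor:
        half_dim = embedding_dim // 2
        freqs = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000.0) / (half_dim - 1)))
        angles = torch.arange(num_embeddings, dtype=torch.float)[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if embedding_dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)  # zero-pad odd dims
        if padding_idx is not None:
            emb[padding_idx, :] = 0  # padding positions get an all-zero vector
        return emb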
@@ -1199,6 +1206,7 @@ class Kosmos2VisionModel(Kosmos2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         return self.model(
             pixel_values=pixel_values,
@@ -1381,12 +1389,16 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
         inputs_embeds=None,
         use_cache=None,
         cache_position=None,
+        is_first_iteration=False,
         **model_kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        #
-
+        # Pixel values are used only in the first iteration if available
+        # In subsequent iterations, they are already merged with text and cached
+        # NOTE: the first iteration doesn't have to be prefill; it can be the first
+        # iteration with a question and a cached system prompt (continuing generation from the cache)
+        if not is_first_iteration and use_cache:
             image_embeds = None
             image_embeds_position_mask = None

@@ -1411,6 +1423,7 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             cache_position=cache_position,
+            is_first_iteration=is_first_iteration,
             **model_kwargs,
         )
         # Kosmos2 has offset for position ids, so we need to create them correctly in PositionEmbedding layer
@@ -264,8 +264,8 @@ class Kosmos2_5ImageProcessorFast(BaseImageProcessorFast):

         encoded_outputs = BatchFeature(
             data={
-                "flattened_patches":
-                "attention_mask":
+                "flattened_patches": flattened_patches,
+                "attention_mask": attention_masks,
                 "width": width,
                 "height": height,
                 "rows": rows,
@@ -619,6 +619,7 @@ class Kosmos2_5TextSinusoidalPositionalEmbedding(nn.Module):
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
         super().__init__()
         self.offset = 2
+        self.num_positions = num_positions
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
@@ -1253,6 +1254,11 @@ class Kosmos2_5PreTrainedModel(PreTrainedModel):
             init.zeros_(module.bias)
         elif isinstance(module, Kosmos2_5ImageToTextProjection):
             init.normal_(module.latent_query, mean=0.0, std=1.0)
+        elif isinstance(module, Kosmos2_5TextSinusoidalPositionalEmbedding):
+            emb_weights = module.get_embedding(
+                module.num_positions + module.offset, module.embedding_dim, module.padding_idx
+            )
+            init.copy_(module.weights, emb_weights)


 class Kosmos2_5VisionModel(Kosmos2_5PreTrainedModel):
@@ -1602,6 +1608,7 @@ class Kosmos2_5TextForCausalLM(Kosmos2_5PreTrainedModel):
         use_cache=None,
         cache_position=None,
         position_ids=None,
+        is_first_iteration=False,
         **model_kwargs,
     ):
         input_shape = input_ids.shape
@@ -1806,6 +1813,7 @@ class Kosmos2_5ForConditionalGeneration(Kosmos2_5PreTrainedModel, GenerationMixin):
         use_cache=None,
         cache_position=None,
         position_ids=None,
+        is_first_iteration=False,
         **model_kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1819,10 +1827,11 @@ class Kosmos2_5ForConditionalGeneration(Kosmos2_5PreTrainedModel, GenerationMixin):
             use_cache=use_cache,
             cache_position=cache_position,
             position_ids=position_ids,
+            is_first_iteration=is_first_iteration,
             **model_kwargs,
         )

-        if cache_position[0] == 0:
+        if is_first_iteration or not use_cache:
             # If we're in cached decoding stage, `flattened_patches` should be `None` because `input_ids` do not contain special image token anymore
             # Otherwise we need `flattened_patches` to be passed to model
             model_inputs["flattened_patches"] = flattened_patches
@@ -39,6 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils.generic import maybe_autocast
 from ..auto import AutoModel
 from .configuration_kyutai_speech_to_text import KyutaiSpeechToTextConfig

@@ -111,6 +112,11 @@ class KyutaiSpeechToTextPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, KyutaiSpeechToTextFlexibleLinear):
             init.normal_(module.weight)
+        if isinstance(module, KyutaiSpeechToTextEmbeddings):
+            audio_tokens_offsets = torch.arange(module.config.num_codebooks) * module.config.codebook_vocab_size
+            audio_tokens_offsets += module.config.vocab_size
+            audio_tokens_offsets = nn.functional.pad(audio_tokens_offsets, (1, 0))
+            init.copy_(module.audio_tokens_offsets, audio_tokens_offsets)


 class KyutaiSpeechToTextConv1dPaddingCache:
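Worked example of the `audio_tokens_offsets` buffer computed above, with toy sizes: each codebook's tokens are shifted past the text vocabulary into their own id range, and `pad(..., (1, 0))` prepends a zero offset, presumably so the first stream keeps unshifted ids:

    import torch
    from torch import nn

    vocab_size, num_codebooks, codebook_vocab_size = 8, 3, 4  # toy sizes
    audio_tokens_offsets = torch.arange(num_codebooks) * codebook_vocab_size
    audio_tokens_offsets += vocab_size
    audio_tokens_offsets = nn.functional.pad(audio_tokens_offsets, (1, 0))
    print(audio_tokens_offsets)  # tensor([ 0,  8, 12, 16])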
@@ -201,6 +207,7 @@ class KyutaiSpeechToTextConv1dPaddingCache:
 class KyutaiSpeechToTextEmbeddings(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.config = config
         self.embed_tokens = nn.Embedding(
             config.vocab_size + (config.num_codebooks * config.codebook_vocab_size) + 1,
             config.hidden_size,
|
|
|
276
283
|
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
|
277
284
|
|
|
278
285
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
279
|
-
self.original_inv_freq =
|
|
286
|
+
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
|
280
287
|
|
|
281
288
|
@staticmethod
|
|
282
289
|
def compute_default_rope_parameters(
|
|
@@ -315,7 +322,7 @@ class KyutaiSpeechToTextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -599,8 +606,8 @@ class KyutaiSpeechToTextFlashAttention2(KyutaiSpeechToTextAttention):
                 else torch.get_autocast_gpu_dtype()
             )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype

@@ -837,6 +844,7 @@ class KyutaiSpeechToTextModel(KyutaiSpeechToTextPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -20,6 +20,7 @@ import numpy as np
 import torch
 import torch.nn as nn

+from ... import initialization as init
 from ...cache_utils import Cache
 from ...feature_extraction_utils import BatchFeature
 from ...generation import GenerationConfig, GenerationMixin
@@ -213,7 +214,13 @@ class KyutaiSpeechToTextFeatureExtractor(EncodecFeatureExtractor):


 class KyutaiSpeechToTextPreTrainedModel(MoshiPreTrainedModel):
-    pass
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, KyutaiSpeechToTextEmbeddings):
+            audio_tokens_offsets = torch.arange(module.config.num_codebooks) * module.config.codebook_vocab_size
+            audio_tokens_offsets += module.config.vocab_size
+            audio_tokens_offsets = nn.functional.pad(audio_tokens_offsets, (1, 0))
+            init.copy_(module.audio_tokens_offsets, audio_tokens_offsets)


 class KyutaiSpeechToTextConv1dPaddingCache(MimiConv1dPaddingCache):
@@ -223,6 +230,7 @@ class KyutaiSpeechToTextConv1dPaddingCache(MimiConv1dPaddingCache):
 class KyutaiSpeechToTextEmbeddings(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.config = config
         self.embed_tokens = nn.Embedding(
             config.vocab_size + (config.num_codebooks * config.codebook_vocab_size) + 1,
             config.hidden_size,
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_lasr import *
+    from .feature_extraction_lasr import *
+    from .modeling_lasr import *
+    from .tokenization_lasr import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
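The new `lasr/__init__.py` follows the package's lazy-module pattern: under `TYPE_CHECKING` the submodules are imported eagerly for static analysis, while at runtime the module object is replaced by a proxy that imports on first attribute access. A generic stand-in showing the mechanism (not transformers' actual `_LazyModule`):

    import importlib
    import types

    class LazyAttrModule(types.ModuleType):
        """Resolve attributes to their defining submodule on first access."""

        def __init__(self, name: str, attr_to_submodule: dict):
            super().__init__(name)
            self._attr_to_submodule = attr_to_submodule

        def __getattr__(self, item: str):
            try:
                path = self._attr_to_submodule[item]
            except KeyError:
                raise AttributeError(item) from None
            value = getattr(importlib.import_module(path), item)
            setattr(self, item, value)  # cache so later lookups skip __getattr__
            return value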