transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl
This diff compares the published contents of two package versions as they appear in their respective public registries. It is provided for informational purposes only.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
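Among the changes above, the largest additions are three new model families (`fast_vlm`, `lasr`, `paddleocr_vl`) plus a broad refactor of the quantizer and integration modules. The two hunks reproduced below correspond to the new `transformers/models/lasr/modular_lasr.py` (+569) and `transformers/models/lasr/processing_lasr.py` (+96) entries.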
transformers/models/lasr/modular_lasr.py

@@ -0,0 +1,569 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+from collections.abc import Callable
+from typing import Optional, Union
+
+import torch
+from tokenizers import Tokenizer
+from tokenizers.models import Unigram
+from torch import nn
+
+from ...masking_utils import create_bidirectional_mask
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...tokenization_utils_tokenizers import TokenizersBackend
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs
+from ..llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward
+from ..parakeet.configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
+from ..parakeet.modeling_parakeet import (
+    ParakeetEncoderBlock,
+    ParakeetEncoderConvolutionModule,
+    ParakeetForCTC,
+    ParakeetPreTrainedModel,
+)
+from ..parakeet.processing_parakeet import ParakeetProcessor
+from ..t5.tokenization_t5 import T5Tokenizer
+
+
+class LasrTokenizer(T5Tokenizer, TokenizersBackend):
+    def __init__(
+        self,
+        eos_token="</s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        extra_ids=100,
+        additional_special_tokens=None,
+        vocab=None,
+        vocab_file=None,
+        **kwargs,
+    ):
+        super().__init__(
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            extra_ids=extra_ids,
+            additional_special_tokens=additional_special_tokens,
+            vocab=vocab,
+            vocab_file=vocab_file,
+            **kwargs,
+        )
+        self._tokenizer = Tokenizer(
+            Unigram(
+                self._vocab_scores,
+                unk_id=3,
+                byte_fallback=False,
+            )
+        )
+
+    def _decode(
+        self,
+        token_ids: Union[int, list[int]],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        group_tokens: bool = True,
+        **kwargs,
+    ) -> str:
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
+        if group_tokens:
+            token_ids = [token_group[0] for token_group in itertools.groupby(token_ids)]
+
+        # for CTC we filter out the blank token, which is the pad token
+        token_ids = [token for token in token_ids if token != self.pad_token_id]
+
+        return TokenizersBackend._decode(
+            self,
+            token_ids=token_ids,
+            skip_special_tokens=skip_special_tokens,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+
+
+class LasrProcessor(ParakeetProcessor):
+    tokenizer_class = "ParakeetTokenizerFast"
+
+
+class LasrEncoderConfig(ParakeetEncoderConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`LasrEncoder`]. It is used to instantiate a
+    `LasrEncoder` model according to the specified arguments, defining the model architecture.
+
+    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PreTrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 512):
+            Dimension of the layers and the hidden states.
+        num_hidden_layers (`int`, *optional*, defaults to 17):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 2048):
+            Dimension of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the encoder and pooler.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use bias in the attention layers.
+        convolution_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use bias in convolutions of the conformer's convolution module.
+        conv_kernel_size (`int`, *optional*, defaults to 32):
+            The kernel size of the convolution layers in the Conformer block.
+        subsampling_conv_channels (`int`, *optional*, defaults to 256):
+            The number of channels in the subsampling convolution layers.
+        subsampling_conv_kernel_size (`int`, *optional*, defaults to 5):
+            The kernel size of the subsampling convolution layers.
+        subsampling_conv_stride (`int`, *optional*, defaults to 2):
+            The stride of the subsampling convolution layers.
+        num_mel_bins (`int`, *optional*, defaults to 128):
+            Number of mel features.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for all fully connected layers in the embeddings, encoder, and pooler.
+        dropout_positions (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the positions in the input sequence.
+        layerdrop (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the layers in the encoder.
+        activation_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for activations inside the fully connected layer.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention layers.
+        max_position_embeddings (`int`, *optional*, defaults to 10000):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the layer normalization layers.
+        feed_forward_residual_weights (`tuple[float, float]`, *optional*, defaults to `[1.5, 0.5]`):
+            The residual weights for the feed forward layers.
+        conv_residual_weights (`tuple[float, float]`, *optional*, defaults to `[2.0, 1.0]`):
+            The residual weights for the convolution layers.
+        batch_norm_momentum (`float`, *optional*, defaults to 0.01):
+            The momentum for the batch normalization layers.
+        rope_parameters (`RopeParameters`, *optional*):
+            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+            with longer `max_position_embeddings`.
+
+    Example:
+    ```python
+    >>> from transformers import LasrEncoderModel, LasrEncoderConfig
+
+    >>> # Initializing a `LasrEncoder` configuration
+    >>> configuration = LasrEncoderConfig()
+
+    >>> # Initializing a model from the configuration
+    >>> model = LasrEncoderModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+
+    This configuration class is based on the LasrEncoder architecture from Google Health AI. You can find more details
+    and pre-trained models at [TODO/TODO](https://huggingface.co/TODO/TODO).
+    """
+
+    def __init__(
+        self,
+        hidden_size=512,
+        num_hidden_layers=17,
+        num_attention_heads=8,
+        intermediate_size=2048,
+        hidden_act="silu",
+        attention_bias=False,
+        convolution_bias=False,
+        conv_kernel_size=32,
+        subsampling_conv_channels=256,
+        subsampling_conv_kernel_size=5,
+        subsampling_conv_stride=2,
+        num_mel_bins=128,
+        dropout=0.1,
+        dropout_positions=0.0,
+        layerdrop=0.1,
+        activation_dropout=0.1,
+        attention_dropout=0.1,
+        max_position_embeddings=10000,
+        initializer_range=0.02,
+        layer_norm_eps=1e-6,
+        feed_forward_residual_weights=[1.5, 0.5],
+        conv_residual_weights=[2.0, 1.0],
+        batch_norm_momentum=0.01,
+        rope_parameters=None,
+        **kwargs,
+    ):
+        self.rope_parameters = rope_parameters
+        self.layer_norm_eps = layer_norm_eps
+        self.feed_forward_residual_weights = feed_forward_residual_weights
+        self.conv_residual_weights = conv_residual_weights
+        self.batch_norm_momentum = batch_norm_momentum
+
+        super().__init__(
+            hidden_size=hidden_size,
+            num_hidden_layers=num_hidden_layers,
+            num_attention_heads=num_attention_heads,
+            intermediate_size=intermediate_size,
+            hidden_act=hidden_act,
+            attention_bias=attention_bias,
+            convolution_bias=convolution_bias,
+            conv_kernel_size=conv_kernel_size,
+            subsampling_conv_channels=subsampling_conv_channels,
+            num_mel_bins=num_mel_bins,
+            subsampling_conv_kernel_size=subsampling_conv_kernel_size,
+            subsampling_conv_stride=subsampling_conv_stride,
+            dropout=dropout,
+            dropout_positions=dropout_positions,
+            layerdrop=layerdrop,
+            activation_dropout=activation_dropout,
+            attention_dropout=attention_dropout,
+            max_position_embeddings=max_position_embeddings,
+            initializer_range=initializer_range,
+            **kwargs,
+        )
+
+        del self.subsampling_factor
+        del self.scale_input
+
+
+class LasrCTCConfig(ParakeetCTCConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`LasrForCTC`]. It is used to instantiate a
+    Lasr CTC model according to the specified arguments, defining the model architecture.
+    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PreTrainedConfig`] for more information.
+    Args:
+        vocab_size (`int`, *optional*, defaults to 512):
+            Vocabulary size of the model.
+        ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
+            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+            instance of [`LasrForCTC`].
+        ctc_zero_infinity (`bool`, *optional*, defaults to `True`):
+            Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+            occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+            of [`LasrForCTC`].
+        encoder_config (`Union[dict, LasrEncoderConfig]`, *optional*):
+            The config object or dictionary of the encoder.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id. Also used as blank token id.
+    Example:
+    ```python
+    >>> from transformers import LasrForCTC, LasrCTCConfig
+    >>> # Initializing a Lasr configuration
+    >>> configuration = LasrCTCConfig()
+    >>> # Initializing a model from the configuration
+    >>> model = LasrForCTC(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    This configuration class is based on the Lasr CTC architecture from Google Health AI. You can find more details
+    and pre-trained models at [TODO/TODO](https://huggingface.co/TODO/TODO).
+    """
+
+    def __init__(
+        self,
+        vocab_size=512,
+        ctc_loss_reduction="mean",
+        ctc_zero_infinity=True,
+        encoder_config: Union[dict, LasrEncoderConfig] = None,
+        pad_token_id=0,
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_size=vocab_size,
+            ctc_loss_reduction=ctc_loss_reduction,
+            ctc_zero_infinity=ctc_zero_infinity,
+            encoder_config=encoder_config,
+            pad_token_id=pad_token_id,
+            **kwargs,
+        )
+
+
+class LasrEncoderSubsampling(nn.Module):
+    def __init__(self, config: LasrEncoderConfig):
+        super().__init__()
+        self.dense_0 = nn.Linear(config.num_mel_bins, config.hidden_size)
+        self.conv_0 = nn.Conv1d(
+            config.hidden_size,
+            config.hidden_size,
+            kernel_size=config.subsampling_conv_kernel_size,
+            stride=config.subsampling_conv_stride,
+        )
+        self.conv_1 = nn.Conv1d(
+            config.hidden_size,
+            config.subsampling_conv_channels,
+            kernel_size=config.subsampling_conv_kernel_size,
+            stride=config.subsampling_conv_stride,
+        )
+        self.dense_1 = nn.Linear(config.subsampling_conv_channels, config.hidden_size)
+        self.act_fn = nn.ReLU()
+
+    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.act_fn(self.dense_0(input_features))
+        hidden_states = hidden_states.transpose(1, 2)
+        hidden_states = self.act_fn(self.conv_0(hidden_states))
+        hidden_states = self.act_fn(self.conv_1(hidden_states))
+        hidden_states = hidden_states.transpose(1, 2)
+        return self.dense_1(hidden_states)
+
+
+class LasrEncoderRotaryEmbedding(LlamaRotaryEmbedding): ...
+
+
+class LasrEncoderAttention(LlamaAttention):
+    def __init__(self, config: LasrEncoderConfig, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.is_causal = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class LasrEncoderConvolutionModule(ParakeetEncoderConvolutionModule):
+    def __init__(self, config: LasrEncoderConfig, module_config=None):
+        super().__init__(config, module_config)
+        self.padding = "same"
+        self.norm = nn.BatchNorm1d(config.hidden_size, momentum=config.batch_norm_momentum)
+
+
+class LasrEncoderBlock(ParakeetEncoderBlock):
+    def __init__(self, config: LasrEncoderConfig, layer_idx: int):
+        super().__init__(config, layer_idx)
+
+        self.feed_forward_residual_weights = config.feed_forward_residual_weights
+        self.conv_residual_weights = config.conv_residual_weights
+
+        self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+        self.norm_self_att = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+        self.norm_conv = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+        self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+        self.norm_out = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
+        hidden_states = (
+            self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
+        )
+
+        normalized_hidden_states = self.norm_self_att(hidden_states)
+        attn_output, _ = self.self_attn(
+            hidden_states=normalized_hidden_states,
+            attention_mask=attention_mask,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = hidden_states + attn_output
+
+        conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
+        hidden_states = self.conv_residual_weights[0] * hidden_states + self.conv_residual_weights[1] * conv_output
+
+        residual = hidden_states
+        hidden_states = self.feed_forward2(self.norm_feed_forward2(hidden_states))
+        hidden_states = (
+            self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
+        )
+
+        hidden_states = self.norm_out(hidden_states)
+
+        return hidden_states
+
+
+class LasrPreTrainedModel(ParakeetPreTrainedModel):
+    def _init_weights(self, module):
+        PreTrainedModel._init_weights(module)
+
+    def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
+        encoder_config = self.config.encoder_config if isinstance(self.config, LasrCTCConfig) else self.config
+        kernel_size = encoder_config.subsampling_conv_kernel_size
+        stride = encoder_config.subsampling_conv_stride
+
+        num_layers = 2
+        for _ in range(num_layers):
+            input_lengths = (input_lengths - kernel_size) // stride + 1
+
+        return input_lengths
+
+
+@auto_docstring(
+    custom_intro="""
+    The LasrEncoder model, based on the [Conformer architecture](https://arxiv.org/abs/2005.08100).
+    """
+)
+class LasrEncoder(LasrPreTrainedModel):
+    config: LasrEncoderConfig
+    base_model_prefix = "encoder"
+
+    def __init__(self, config: LasrEncoderConfig):
+        super().__init__(config)
+        self.gradient_checkpointing = False
+
+        self.dropout = config.dropout
+        self.dropout_positions = config.dropout_positions
+        self.layerdrop = config.layerdrop
+
+        self.subsampler = LasrEncoderSubsampling(config)
+        self.rotary_emb = LasrEncoderRotaryEmbedding(config)
+        self.layers = nn.ModuleList(
+            [LasrEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
+
+        self.post_init()
+
+    @auto_docstring
+    @check_model_inputs()
+    @can_return_tuple
+    def forward(
+        self,
+        input_features: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import AutoProcessor, LasrEncoder
+        >>> from datasets import load_dataset, Audio
+
+        >>> model_id = TODO
+        >>> processor = AutoProcessor.from_pretrained(model_id)
+        >>> encoder = LasrEncoder.from_pretrained(model_id)
+
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+
+        >>> inputs = processor(ds[0]["audio"]["array"])
+        >>> encoder_outputs = encoder(**inputs)
+
+        >>> print(encoder_outputs.last_hidden_state.shape)
+        ```
+        """
+
+        hidden_states = self.subsampler(input_features)
+        cos, sin = self.rotary_emb(
+            hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
+        )
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
+        sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)
+
+        if attention_mask is not None:
+            attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
+
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=hidden_states,
+            attention_mask=attention_mask,
+        )
+
+        for encoder_layer in self.layers:
+            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if not to_drop:
+                hidden_states = encoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_embeddings=(cos, sin),
+                    **kwargs,
+                )
+
+        hidden_states = self.out_norm(hidden_states)
+
+        return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+class LasrForCTC(ParakeetForCTC):
+    def generate(**super_kwargs):
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import AutoProcessor, LasrForCTC
+        >>> from datasets import load_dataset, Audio
+
+        >>> model_id = TODO
+        >>> processor = AutoProcessor.from_pretrained(model_id)
+        >>> model = LasrForCTC.from_pretrained(model_id)
+
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+
+        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
+        >>> predicted_ids = model.generate(**inputs)
+        >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+
+        >>> print(transcription)
+        ```
+        """
+        return super().generate(**super_kwargs)
+
+
+__all__ = [
+    "LasrForCTC",
+    "LasrEncoder",
+    "LasrPreTrainedModel",
+    "LasrProcessor",
+    "LasrEncoderConfig",
+    "LasrCTCConfig",
+    "LasrTokenizer",
+]
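Two pieces of `modular_lasr.py` above lend themselves to a quick standalone check: the output-length arithmetic in `_get_subsampling_output_length` and the CTC-style collapse in `LasrTokenizer._decode`. The sketch below mirrors both using the config defaults (`subsampling_conv_kernel_size=5`, `subsampling_conv_stride=2`, blank/pad id 0 per `LasrCTCConfig`); the token ids are made up for illustration:

```python
import itertools

def subsampled_length(length: int, kernel_size: int = 5, stride: int = 2, num_layers: int = 2) -> int:
    # Same arithmetic as LasrPreTrainedModel._get_subsampling_output_length:
    # two stacked Conv1d layers, each shrinking L to floor((L - kernel) / stride) + 1.
    for _ in range(num_layers):
        length = (length - kernel_size) // stride + 1
    return length

print(subsampled_length(1000))  # 247 -- roughly a 4x reduction in time steps

# CTC-style decoding as in LasrTokenizer._decode: first collapse runs of repeated
# ids, then drop the blank token (the pad token, id 0 by default).
token_ids = [7, 7, 0, 7, 5, 5, 0, 0, 9]                       # hypothetical frame-level ids
collapsed = [key for key, _ in itertools.groupby(token_ids)]  # [7, 0, 7, 5, 0, 9]
no_blanks = [t for t in collapsed if t != 0]                  # [7, 7, 5, 9]
print(no_blanks)
```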
transformers/models/lasr/processing_lasr.py

@@ -0,0 +1,96 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/lasr/modular_lasr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_lasr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+from ...audio_utils import AudioInput, make_list_of_audio
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class LasrProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "audio_kwargs": {
+            "sampling_rate": 16000,
+            "padding": "longest",
+            "return_attention_mask": True,
+        },
+        "text_kwargs": {
+            "padding": True,
+            "padding_side": "right",
+            "add_special_tokens": False,
+        },
+        "common_kwargs": {"return_tensors": "pt"},
+    }
+
+
+class LasrProcessor(ProcessorMixin):
+    tokenizer_class = "ParakeetTokenizerFast"
+
+    def __init__(self, feature_extractor, tokenizer):
+        super().__init__(feature_extractor, tokenizer)
+
+    def __call__(
+        self,
+        audio: AudioInput,
+        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None,
+        sampling_rate: Optional[int] = None,
+        **kwargs: Unpack[LasrProcessorKwargs],
+    ):
+        audio = make_list_of_audio(audio)
+
+        output_kwargs = self._merge_kwargs(
+            LasrProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        if sampling_rate is None:
+            logger.warning_once(
+                f"You've provided audio without specifying the sampling rate. It will be assumed to be {output_kwargs['audio_kwargs']['sampling_rate']}, which can result in silent errors."
+            )
+        elif sampling_rate != output_kwargs["audio_kwargs"]["sampling_rate"]:
+            raise ValueError(
+                f"The sampling rate of the audio ({sampling_rate}) does not match the sampling rate of the processor ({output_kwargs['audio_kwargs']['sampling_rate']}). Please resample the audio to the expected sampling rate."
+            )
+
+        if audio is not None:
+            inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
+        if text is not None:
+            encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+        if text is None:
+            return inputs
+        else:
+            inputs["labels"] = encodings["input_ids"]
+            return inputs
+
+    @property
+    def model_input_names(self):
+        feature_extractor_input_names = self.feature_extractor.model_input_names
+        return feature_extractor_input_names + ["labels"]
+
+
+__all__ = ["LasrProcessor"]