optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
16
|
+
|
|
17
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
18
|
+
# you may not use this file except in compliance with the License.
|
|
19
|
+
# You may obtain a copy of the License at:
|
|
20
|
+
|
|
21
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
22
|
+
|
|
23
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
24
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
25
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
26
|
+
# See the License for the specific language governing permissions and
|
|
27
|
+
# limitations under the License.
|
|
28
|
+
|
|
29
|
+
"""
|
|
30
|
+
Generation utilities for Whisper.
|
|
31
|
+
Modified from `transformers.models.whisper.generation_whisper.py`
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
from typing import Any, Dict, Optional, Union
|
|
35
|
+
|
|
36
|
+
import torch
|
|
37
|
+
import transformers
|
|
38
|
+
from packaging import version
|
|
39
|
+
from transformers import GenerationMixin
|
|
40
|
+
from transformers.generation.configuration_utils import GenerationConfig
|
|
41
|
+
from transformers.modeling_outputs import ModelOutput
|
|
42
|
+
from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
    """Whisper generation mixin for models compiled for RBLN devices.

    ``generate`` is restricted to greedy decoding (``num_beams=1``) before
    delegating to the HuggingFace implementation, and ``_postprocess_outputs``
    is overridden so that per-sample outputs are split without
    ``past_key_values`` (which are not kept on the RBLN side).
    """

    def generate(
        self,
        input_features: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        return_segments: Optional[bool] = None,
        return_timestamps: Optional[bool] = None,
        return_token_timestamps: Optional[bool] = None,
        **kwargs,
    ) -> Union[ModelOutput, Dict[str, Any], torch.LongTensor]:
        """
        The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate) for more details.

        Args:
            input_features(torch.Tensor, optional): The input features to the model.
            attention_mask(torch.Tensor, optional): Attention mask needs to be passed when doing long-form transcription using a batch size > 1.
            generation_config(GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
                If generation_config is not provided, the default will be used, which has the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
                Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)'s default values.
            return_segments(bool, optional): Whether to return segments.
            return_timestamps(bool, optional): Whether to return the timestamps with the text. For audios longer than 30 seconds, it is necessary to set return_timestamps=True.
            return_token_timestamps(bool, optional): Whether to return token timestamps.
            kwargs(dict[str, Any], optional): Additional arguments passed to the generate function.
                Only num_beams=1 (greedy search) is accepted if num_beams is given.

        Returns:
            Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.

        Raises:
            ValueError: If num_beams is passed with a value other than 1.
        """
        num_beams = kwargs.get("num_beams", None)
        # Beam search is rejected up front; an explicit num_beams=1 (or no value) is allowed through.
        if num_beams is not None and num_beams != 1:
            raise ValueError(
                "Beam search is not supported in RBLNWhisperGenerationMixin. "
                f"Received num_beams={num_beams}, but only num_beams=1 is allowed. "
                "Please set num_beams=1 for greedy search or adjust your configuration."
            )

        return super().generate(
            input_features,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_segments=return_segments,
            return_timestamps=return_timestamps,
            return_token_timestamps=return_token_timestamps,
            **kwargs,
        )

    def _postprocess_outputs(
        self,
        seek_outputs,
        decoder_input_ids,
        return_token_timestamps,
        generation_config,
        is_shortform,
        seek,
        batch_idx_map,
    ):
        """Strip the decoder prompt from generated sequences and split outputs per batch row.

        Args:
            seek_outputs: Either a tensor of generated token ids, or a mapping of
                generation outputs containing at least ``"sequences"``.
            decoder_input_ids: Decoder prompt ids; their last-dimension length is
                the number of leading positions removed from the sequences.
            return_token_timestamps: Whether to compute ``"token_timestamps"``.
            generation_config: Generation config; its ``alignment_heads`` and
                optional ``num_frames`` are read for token timestamps.
            is_shortform: Accepted but not used by this override.
            seek: Subtracted from ``generation_config.num_frames`` when that is set.
            batch_idx_map: Index applied to the adjusted ``num_frames``.

        Returns:
            A ``(sequences, seek_outputs)`` tuple. If ``seek_outputs`` is a tensor,
            the second element is that tensor unchanged. Otherwise it is a list
            with one dict per batch row, holding only the non-empty output entries.

        Raises:
            RuntimeError: If token timestamps are requested but the model was not
                compiled with ``rbln_token_timestamps=True``.
        """
        # remove all previously passed decoder input ids
        # should happen only if it is the first generated segment
        start_idx = decoder_input_ids.shape[-1]

        if isinstance(seek_outputs, torch.Tensor):
            return seek_outputs[:, start_idx:], seek_outputs

        if return_token_timestamps and not self.rbln_token_timestamps:
            raise RuntimeError(
                "To use .generate() with return_token_timestamps=True, the model must be compiled with rbln_token_timestamps=True. "
                "You can compile the model by calling .from_pretrained() with export=True and rbln_token_timestamps=True as keyword arguments, "
                "or you can generate with return_token_timestamps=False."
            )

        if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
            num_frames = getattr(generation_config, "num_frames", None)

            if num_frames is not None:
                num_frames = num_frames - seek
                num_frames = num_frames[batch_idx_map]

            # `num_input_ids` is only passed on transformers >= 4.46.0; older versions get the shorter call.
            if version.parse(transformers.__version__) >= version.parse("4.46.0"):
                seek_outputs["token_timestamps"] = self._extract_token_timestamps(
                    seek_outputs,
                    generation_config.alignment_heads,
                    num_frames=num_frames,
                    num_input_ids=decoder_input_ids.shape[-1],
                )
            else:
                seek_outputs["token_timestamps"] = self._extract_token_timestamps(
                    seek_outputs,
                    generation_config.alignment_heads,
                    num_frames=num_frames,
                )
        seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]

        def split_by_batch_index(values, key, batch_idx):
            """Select the slice of one output entry that belongs to batch row ``batch_idx``."""
            if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
                return [v[batch_idx].cpu() for v in values]
            if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
                return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
            if key == "past_key_values":
                # we don't save `past_key_values` in rbln
                return None
            return values[batch_idx].cpu()

        sequence_tokens = seek_outputs["sequences"]

        # Keep only entries that actually hold data (non-None, non-empty, first element non-None).
        valid_seekoutputs = [
            (k, v) for k, v in seek_outputs.items() if v is not None and len(v) > 0 and v[0] is not None
        ]
        seek_outputs = [
            {k: split_by_batch_index(v, k, i) for k, v in valid_seekoutputs} for i in range(sequence_tokens.shape[0])
        ]

        return sequence_tokens, seek_outputs
|
|
@@ -0,0 +1,475 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import inspect
|
|
16
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
|
17
|
+
|
|
18
|
+
import rebel
|
|
19
|
+
import torch
|
|
20
|
+
from rebel.compile_context import CompileContext
|
|
21
|
+
from transformers import AutoModelForSpeechSeq2Seq, WhisperForConditionalGeneration, WhisperModel
|
|
22
|
+
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
|
23
|
+
|
|
24
|
+
from ....configuration_utils import RBLNCompileConfig
|
|
25
|
+
from ....modeling import RBLNModel
|
|
26
|
+
from ....utils.logging import get_logger
|
|
27
|
+
from ....utils.runtime_utils import RBLNPytorchRuntime
|
|
28
|
+
from .configuration_whisper import RBLNWhisperForConditionalGenerationConfig
|
|
29
|
+
from .generation_whisper import RBLNWhisperGenerationMixin
|
|
30
|
+
from .whisper_architecture import WhisperWrapper
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Module-level logger from optimum-rbln's `utils.logging.get_logger`, named after this module.
logger = get_logger(__name__)
|
|
34
|
+
|
|
35
|
+
if TYPE_CHECKING:
|
|
36
|
+
from transformers import (
|
|
37
|
+
AutoFeatureExtractor,
|
|
38
|
+
AutoProcessor,
|
|
39
|
+
AutoTokenizer,
|
|
40
|
+
GenerationConfig,
|
|
41
|
+
PretrainedConfig,
|
|
42
|
+
PreTrainedModel,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class RBLNRuntimeEncoder(RBLNPytorchRuntime):
    """Runtime wrapper for the compiled Whisper encoder.

    Runs the underlying ``RBLNPytorchRuntime.forward`` unchanged and packages
    its raw result as the ``last_hidden_state`` of a ``BaseModelOutput``.
    """

    # `main_input_name` must be supplied when this runtime is constructed.
    mandatory_members = ["main_input_name"]

    def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
        """Execute the encoder and return a ``BaseModelOutput``."""
        encoder_result = super().forward(*args, **kwargs)
        wrapped_result = BaseModelOutput(last_hidden_state=encoder_result)
        return wrapped_result
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class RBLNRuntimeDecoder(RBLNPytorchRuntime):
    """Runtime wrapper for the compiled Whisper decoder.

    Pads the incoming batch up to the compiled ``batch_size`` and, when
    ``use_attention_mask`` is set, fills the decoder attention mask up to each
    row's current decoding step. The runtime result is wrapped in a
    ``Seq2SeqLMOutput`` that is trimmed back to the caller's batch size.
    """

    # `main_input_name` must be supplied when this runtime is constructed.
    mandatory_members = ["main_input_name"]

    def __init__(
        self,
        runtime: rebel.Runtime,
        batch_size: int,
        dec_max_seq_len: int,
        use_attention_mask: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            runtime: Compiled decoder runtime, handed to ``RBLNPytorchRuntime``.
            batch_size: Batch size the decoder was compiled for; smaller inputs are padded to it.
            dec_max_seq_len: Exclusive upper bound accepted for ``cache_position`` values.
            use_attention_mask: Whether an explicit decoder attention mask is passed to the runtime.
            **kwargs: Forwarded to ``RBLNPytorchRuntime.__init__`` (``main_input_name`` is
                listed in ``mandatory_members``).
        """
        super().__init__(runtime, **kwargs)
        self.batch_size = batch_size
        self.dec_max_seq_len = dec_max_seq_len
        self.use_attention_mask = use_attention_mask
        # Shape (batch_size, 1) int16 table with row i -> i; used when forward() gets no block_tables.
        self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)

    def forward(
        self,
        decoder_input_ids: torch.Tensor = None,
        decoder_attention_mask: torch.Tensor = None,
        cache_position: torch.Tensor = None,
        block_tables: torch.Tensor = None,
    ):
        """Run one decoder step and return a ``Seq2SeqLMOutput``.

        Args:
            decoder_input_ids: Decoder token ids. Dim 0 is the batch and is padded to ``batch_size``.
            decoder_attention_mask: Mask filled in place (set to 1 up to each row's decoding step)
                when ``use_attention_mask`` is enabled; otherwise ``None`` is sent to the runtime.
            cache_position: Per-row decoding step. ``cache_position[b]`` is read for every
                ``b < batch_size`` when the attention mask is used.
            block_tables: Optional block table; defaults to ``self.default_block_tables``.

        Returns:
            ``Seq2SeqLMOutput`` with logits (and cross-attentions, when the runtime returns them)
            sliced back to the input batch size.

        Raises:
            ValueError: If a decoding step falls outside ``[0, dec_max_seq_len)``.
        """
        inputs_bsz = decoder_input_ids.shape[0]
        padded_bsz = self.batch_size - inputs_bsz

        if padded_bsz > 0:
            # Pad (0, 0) on the last dim and (0, padded_bsz) on the one before it.
            # Assumes decoder_input_ids is 2-D (batch, seq) — TODO confirm.
            decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, 0, 0, padded_bsz))

        if self.use_attention_mask:
            for b_idx in range(self.batch_size):
                decoding_step = cache_position[b_idx].item()
                if not (0 <= decoding_step < self.dec_max_seq_len):
                    # Report the shape of the mask actually being filled; this class sets no
                    # `dec_attn_mask` attribute, so referencing it here would raise AttributeError.
                    raise ValueError(
                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {decoder_attention_mask.shape}."
                    )
                decoder_attention_mask[b_idx, : decoding_step + 1] = 1

        if block_tables is None:
            block_tables = self.default_block_tables

        outputs = super().forward(
            decoder_input_ids,
            decoder_attention_mask if self.use_attention_mask else None,
            cache_position,
            block_tables=block_tables,
        )

        # A bare tensor means logits only. Otherwise the runtime returned (logits, cross_attentions).
        # For cross_attentions the batch presumably sits on dim 1 — verify against the compiled graph.
        if isinstance(outputs, torch.Tensor):
            return Seq2SeqLMOutput(logits=outputs[:inputs_bsz], cross_attentions=None)
        else:
            return Seq2SeqLMOutput(logits=outputs[0][:inputs_bsz], cross_attentions=outputs[1][:, :inputs_bsz])
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
|
|
110
|
+
"""
|
|
111
|
+
Whisper model for speech recognition and transcription optimized for RBLN NPU.
|
|
112
|
+
|
|
113
|
+
This model inherits from [`RBLNModel`]. It implements the methods to convert and run
|
|
114
|
+
pre-trained transformers based Whisper model on RBLN devices by:
|
|
115
|
+
|
|
116
|
+
- transferring the checkpoint weights of the original into an optimized RBLN graph,
|
|
117
|
+
- compiling the resulting graph using the RBLN compiler.
|
|
118
|
+
|
|
119
|
+
Example (Short form):
|
|
120
|
+
```python
|
|
121
|
+
import torch
|
|
122
|
+
from transformers import AutoProcessor
|
|
123
|
+
from datasets import load_dataset
|
|
124
|
+
from optimum.rbln import RBLNWhisperForConditionalGeneration
|
|
125
|
+
|
|
126
|
+
# Load processor and dataset
|
|
127
|
+
model_id = "openai/whisper-tiny"
|
|
128
|
+
processor = AutoProcessor.from_pretrained(model_id)
|
|
129
|
+
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
|
130
|
+
|
|
131
|
+
# Prepare input features
|
|
132
|
+
input_features = processor(
|
|
133
|
+
ds[0]["audio"]["array"],
|
|
134
|
+
sampling_rate=ds[0]["audio"]["sampling_rate"],
|
|
135
|
+
return_tensors="pt"
|
|
136
|
+
).input_features
|
|
137
|
+
|
|
138
|
+
# Load and compile model (or load pre-compiled model)
|
|
139
|
+
model = RBLNWhisperForConditionalGeneration.from_pretrained(
|
|
140
|
+
model_id=model_id,
|
|
141
|
+
export=True,
|
|
142
|
+
rbln_batch_size=1
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
# Generate transcription
|
|
146
|
+
outputs = model.generate(input_features=input_features, return_timestamps=True)
|
|
147
|
+
transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
|
148
|
+
print(f"Transcription: {transcription}")
|
|
149
|
+
```
|
|
150
|
+
"""
|
|
151
|
+
|
|
152
|
+
auto_model_class = AutoModelForSpeechSeq2Seq
|
|
153
|
+
main_input_name = "input_features"
|
|
154
|
+
_is_stateful = False
|
|
155
|
+
|
|
156
|
+
def __post_init__(self, **kwargs):
    """Finish initialisation once the base class has set up the runtimes.

    Copies sizing and feature flags from ``self.rbln_config`` onto the
    instance and wraps the runtimes at ``self.model[0]`` / ``self.model[1]``
    in ``RBLNRuntimeEncoder`` / ``RBLNRuntimeDecoder``. It then resets the
    language-detection state and replaces ``self.model`` with a
    ``WhisperModel`` constructed from ``self.config``.
    """
    super().__post_init__(**kwargs)

    rbln_cfg = self.rbln_config
    self.batch_size = rbln_cfg.batch_size
    self.dec_max_seq_len = rbln_cfg.dec_max_seq_len
    self.rbln_token_timestamps = rbln_cfg.token_timestamps
    self.use_attention_mask = rbln_cfg.use_attention_mask

    # These must be read before self.model is reassigned below.
    encoder_runtime = self.model[0]
    decoder_runtime = self.model[1]

    self.encoder = RBLNRuntimeEncoder(runtime=encoder_runtime, main_input_name="input_features")
    self.decoder = RBLNRuntimeDecoder(
        runtime=decoder_runtime,
        main_input_name="input_ids",
        batch_size=self.batch_size,
        dec_max_seq_len=self.dec_max_seq_len,
        use_attention_mask=self.use_attention_mask,
    )

    # Language-detection bookkeeping. After a detection pass, forward() reuses
    # the cached cross-attentions instead of re-running decoder step 0.
    self.is_language_detected = False
    self.language_cross = None

    # transformers' Whisper generate() reads
    # self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
    # (generation_whisper.py), so a WhisperModel built from the config is kept
    # for those attributes.
    self.model = WhisperModel(self.config)
    self.pad_token_id = self.config.pad_token_id
|
|
182
|
+
|
|
183
|
+
def can_generate(self):
    """Report that ``generate()`` is supported; the answer is always ``True``."""
    supports_generation = True
    return supports_generation
|
|
185
|
+
|
|
186
|
+
def get_encoder(self):
    """Return the encoder runtime wrapper stored on ``self.encoder``."""
    encoder = self.encoder
    return encoder
|
|
188
|
+
|
|
189
|
+
def get_decoder(self):
    """Return the decoder runtime wrapper stored on ``self.decoder``."""
    decoder = self.decoder
    return decoder
|
|
191
|
+
|
|
192
|
+
def __getattr__(self, __name: str) -> Any:
    """Fall back to attributes of ``WhisperForConditionalGeneration``.

    Python calls this only when normal attribute lookup fails. A callable
    whose signature has a ``self`` parameter is returned as a wrapper that
    passes this instance as ``self``. Any other value is returned unchanged.
    A name missing on ``WhisperForConditionalGeneration`` raises
    ``AttributeError`` from ``getattr``.
    """
    fallback = getattr(WhisperForConditionalGeneration, __name)

    takes_self = isinstance(fallback, Callable) and "self" in inspect.signature(fallback).parameters
    if not takes_self:
        return fallback

    def bound_to_instance(*pargs, **kwargs):
        return fallback(self, *pargs, **kwargs)

    return bound_to_instance
|
|
200
|
+
|
|
201
|
+
def _reorder_cache(self, past_key_values, beam_idx):
    """Cache reordering (used by beam search) is not available for this model.

    Raises:
        NotImplementedError: Always.
    """
    # TODO(jongho): implement
    raise NotImplementedError
|
|
204
|
+
|
|
205
|
+
@classmethod
def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNWhisperForConditionalGenerationConfig):
    """Wrap a Hugging Face Whisper model in ``WhisperWrapper`` for compilation.

    The first parameter is the class, since this is a classmethod, hence ``cls``.

    Args:
        model: The source PyTorch Whisper model.
        rbln_config: Supplies ``use_attention_mask`` and ``token_timestamps``,
            which are forwarded to the wrapper.

    Returns:
        A ``WhisperWrapper``. ``get_compiled_model`` reads its ``encoder`` and
        ``decoder`` attributes.
    """
    return WhisperWrapper(
        model,
        use_attention_mask=rbln_config.use_attention_mask,
        rbln_token_timestamps=rbln_config.token_timestamps,
    )
|
|
212
|
+
|
|
213
|
+
@classmethod
@torch.inference_mode()
def get_compiled_model(cls, model, rbln_config: RBLNWhisperForConditionalGenerationConfig):
    """Compile the wrapped Whisper encoder and decoder.

    Both compilations share one ``CompileContext``. Every dummy input whose
    name contains ``"key_value_states"`` is registered with
    ``context.mark_static_address``. The encoder's matching tensors are also
    passed to the decoder's ``get_dummy_inputs`` via ``static_tensors``.

    Returns:
        ``{"encoder": <compiled encoder>, "decoder": <compiled decoder>}``.
    """
    wrapped = cls._wrap_model_if_needed(model, rbln_config)

    encoder_cfg = rbln_config.compile_cfgs[0]
    decoder_cfg = rbln_config.compile_cfgs[1]

    context = CompileContext(use_weight_sharing=False)

    def compile_part(module, compile_cfg, example_inputs):
        # Common compile call; runtime/device options are read from rbln_config each time.
        return cls.compile(
            module,
            compile_cfg,
            create_runtimes=rbln_config.create_runtimes,
            device=rbln_config.device,
            example_inputs=example_inputs,
            compile_context=context,
        )

    encoder_inputs = encoder_cfg.get_dummy_inputs(fill=0)

    # Encoder-side cross KV states: pin them and keep them by name so the
    # decoder's dummy inputs are built with the same tensors.
    static_tensors = {}
    for (input_name, _, _), dummy in zip(encoder_cfg.input_info, encoder_inputs):
        if "key_value_states" not in input_name:
            continue
        static_tensors[input_name] = dummy
        context.mark_static_address(dummy)

    decoder_inputs = decoder_cfg.get_dummy_inputs(fill=0, static_tensors=static_tensors)

    # Decoder-side KV states. The substring filter matches both the cross_ and
    # the self_ key_value_states inputs.
    for (input_name, _, _), dummy in zip(decoder_cfg.input_info, decoder_inputs):
        if "key_value_states" in input_name:
            context.mark_static_address(dummy)

    return {
        "encoder": compile_part(wrapped.encoder, encoder_cfg, encoder_inputs),
        "decoder": compile_part(wrapped.decoder, decoder_cfg, decoder_inputs),
    }
|
|
257
|
+
|
|
258
|
+
@classmethod
def _update_paged_attention_config(
    cls, model_config: "PretrainedConfig", rbln_config: RBLNWhisperForConditionalGenerationConfig
):
    """Default and validate the KV-cache paging fields of ``rbln_config``.

    A falsy ``kvcache_num_blocks`` defaults to ``batch_size``, and a falsy
    ``kvcache_block_size`` defaults to ``dec_max_seq_len``. Any other
    combination is rejected, because only one full-length block per batch
    row is supported.

    Raises:
        NotImplementedError: If either value differs from its required setting.
    """
    cfg = rbln_config
    cfg.kvcache_num_blocks = cfg.kvcache_num_blocks or cfg.batch_size
    cfg.kvcache_block_size = cfg.kvcache_block_size or cfg.dec_max_seq_len

    if cfg.kvcache_num_blocks != cfg.batch_size:
        raise NotImplementedError(
            f"kvcache_num_blocks ({cfg.kvcache_num_blocks}) must be equal to batch_size ({cfg.batch_size}) as flash attention is not supported yet."
        )

    if cfg.kvcache_block_size != cfg.dec_max_seq_len:
        raise NotImplementedError(
            f"kvcache_block_size ({cfg.kvcache_block_size}) must be equal to dec_max_seq_len ({cfg.dec_max_seq_len}) as flash attention is not supported yet."
        )
|
|
274
|
+
|
|
275
|
+
@classmethod
def _update_rbln_config(
    cls,
    preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
    model: Optional["PreTrainedModel"] = None,
    model_config: Optional["PretrainedConfig"] = None,
    rbln_config: Optional[RBLNWhisperForConditionalGenerationConfig] = None,
) -> RBLNWhisperForConditionalGenerationConfig:
    """Fill in sequence lengths and the encoder/decoder compile input specs.

    Steps:
    - Set ``rbln_config.enc_max_seq_len`` from ``model_config.max_source_positions``.
    - Set ``rbln_config.dec_max_seq_len`` from ``max_target_positions``, falling
      back to ``max_length``.
    - Default and validate the KV-cache paging fields via
      ``_update_paged_attention_config``.
    - Register two ``RBLNCompileConfig`` objects, named ``"encoder"`` and
      ``"decoder"``.

    Args:
        preprocessors: Not used here.
        model: Not used here.
        model_config: Hugging Face Whisper config supplying sizes and head counts.
        rbln_config: Updated in place.

    Returns:
        The same ``rbln_config`` instance.

    Raises:
        NotImplementedError: Propagated from ``_update_paged_attention_config``.
    """
    # The encoder input is twice max_source_positions frames long (presumably
    # the stride-2 conv front-end halves it back down).
    expected_seq_len = model_config.max_source_positions * 2
    num_mel_bins = model_config.num_mel_bins
    rbln_config.enc_max_seq_len = model_config.max_source_positions

    # 'whisper-large-v3-turbo' doesn't have 'max_length', but PretrainedConfig have default value for the key 'max_length'
    rbln_config.dec_max_seq_len = getattr(model_config, "max_target_positions", None)
    if rbln_config.dec_max_seq_len is None:
        rbln_config.dec_max_seq_len = model_config.max_length

    cls._update_paged_attention_config(model_config, rbln_config)

    batch_size = rbln_config.batch_size
    # Leading dim of the stacked KV states. Presumably this is one key and one
    # value per decoder layer.
    num_kv_states = model_config.decoder_layers * 2
    # Cross- and self-attention KV caches are both decoder attention states, so
    # their head count and head size both come from decoder_attention_heads.
    num_heads = model_config.decoder_attention_heads
    head_dim = model_config.d_model // num_heads

    def cross_kv_info():
        # Build a fresh spec per use so the encoder and decoder lists never
        # share one mutable shape list.
        return (
            "cross_key_value_states",
            [num_kv_states, batch_size, num_heads, rbln_config.enc_max_seq_len, head_dim],
            "float32",
        )

    enc_input_info = [
        ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
        ("block_tables", [1], "int16"),
        cross_kv_info(),
    ]

    dec_input_info = [
        ("decoder_input_ids", [batch_size, 1], "int64"),
        ("cache_position", [batch_size, 1], "int32"),
        ("block_tables", [batch_size, 1], "int16"),
        cross_kv_info(),
    ]
    dec_input_info.extend(
        (
            f"self_key_value_states_{i}",
            [batch_size, num_heads, rbln_config.dec_max_seq_len, head_dim],
            "float32",
        )
        for i in range(num_kv_states)
    )

    if rbln_config.use_attention_mask:
        # Second position, matching the positional order
        # (ids, mask, cache_position) used when the decoder runtime is called.
        dec_input_info.insert(
            1, ("decoder_attention_mask", [batch_size, rbln_config.dec_max_seq_len], "float32")
        )

    enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
    dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

    rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])

    return rbln_config
|
|
357
|
+
|
|
358
|
+
@classmethod
def _create_runtimes(
    cls,
    compiled_models: List[rebel.RBLNCompiledModel],
    rbln_config: RBLNWhisperForConditionalGenerationConfig,
) -> List[rebel.Runtime]:
    """Create one ``rebel.Runtime`` each for the compiled encoder and decoder.

    ``compiled_models[0]`` is placed on ``device_map["encoder"]`` and
    ``compiled_models[1]`` on ``device_map["decoder"]``. Both use PyTorch
    tensors and share the profiler and timeout settings from ``rbln_config``.
    If either name is missing from ``device_map``,
    ``_raise_missing_compiled_file_error`` is invoked first.
    """
    part_names = ["encoder", "decoder"]
    missing = [name for name in part_names if name not in rbln_config.device_map]
    if missing:
        cls._raise_missing_compiled_file_error(part_names)

    return [
        rebel.Runtime(
            compiled_models[idx],
            tensor_type="pt",
            device=rbln_config.device_map[name],
            activate_profiler=rbln_config.activate_profiler,
            timeout=rbln_config.timeout,
        )
        for idx, name in enumerate(part_names)
    ]
|
|
383
|
+
|
|
384
|
+
def prepare_inputs_for_generation(
    self,
    input_ids,
    cache_position: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,  # accepted for transformers>=4.45.0 compatibility
    **kwargs,
):
    """Pick the arguments ``forward`` needs for one generation step.

    Only ``input_ids`` and ``cache_position`` are forwarded. ``attention_mask``
    and any other keyword arguments are accepted so the signature fits
    ``generate()``, but they are ignored.
    """
    model_inputs = dict(input_ids=input_ids, cache_position=cache_position)
    return model_inputs
|
|
395
|
+
|
|
396
|
+
# https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L512
def _prepare_encoder_decoder_kwargs_for_generation(
    self,
    inputs_tensor: torch.Tensor,
    model_kwargs,
    model_input_name: Optional[str] = None,
    generation_config: Optional["GenerationConfig"] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Run the encoder before generation unless language detection already did.

    Pads ``inputs_tensor`` along dim 0 (zero rows appended) up to
    ``self.batch_size``.

    If no language-detection pass is pending, the encoder runs once per row
    with ``block_tables=[row]``. Only the last call's output is stored in
    ``model_kwargs["encoder_outputs"]``. ``self.decoder_attention_mask`` is
    then reset to zeros.

    If detection already ran, a placeholder ``BaseModelOutput`` is stored
    instead.

    Returns:
        The same ``model_kwargs`` dict, updated in place.
    """
    missing_rows = self.batch_size - inputs_tensor.shape[0]
    if missing_rows > 0:
        # 6-tuple pads the last three dims; only the third-from-last (dim 0 of
        # 3-D features) gets rows appended.
        inputs_tensor = torch.nn.functional.pad(inputs_tensor, (0, 0, 0, 0, 0, missing_rows))

    if self.is_language_detected:
        # The detection pass already ran the encoder; only a placeholder is needed here.
        model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
        return model_kwargs

    for row in range(inputs_tensor.shape[0]):
        row_table = torch.tensor([row], dtype=torch.int16)
        model_kwargs["encoder_outputs"] = self.encoder(
            input_features=inputs_tensor[row].unsqueeze(0), block_tables=row_table
        )
    self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)

    return model_kwargs
|
|
421
|
+
|
|
422
|
+
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.Tensor] = None,
    input_features: Optional[torch.Tensor] = None,
    decoder_input_ids: Optional[torch.Tensor] = None,
    encoder_outputs: Optional[Seq2SeqLMOutput] = None,
    **kwargs,
) -> Seq2SeqLMOutput:
    """Run decoder steps for generation, or a language-detection pass.

    The mode is chosen from the arguments.

    Default decoder pass (``input_features`` and ``encoder_outputs`` are both None):
        Runs the decoder once for each entry ``step`` of ``cache_position``,
        feeding ``input_ids[:, step : step + 1]``. If a detection pass already
        ran step 0 (``self.is_language_detected``), that step is not re-run.
        Its cached ``self.language_cross`` is reused and the flag is cleared.

    Language-detection pass (otherwise):
        If ``encoder_outputs`` is None, runs the encoder on each row of
        ``input_features`` (the return value is discarded). It then resets
        ``self.decoder_attention_mask`` and runs the decoder at position 0 on
        ``decoder_input_ids``. The resulting cross-attentions are stored in
        ``self.language_cross`` and ``self.is_language_detected`` is set.

    Returns:
        Seq2SeqLMOutput.
        - Default pass: ``logits`` come from the last decoder step executed.
          ``cross_attentions`` are concatenated along dim -2 when
          ``self.rbln_token_timestamps`` is set, otherwise None.
        - Detection pass: only ``logits`` are set.
    """
    # default decoder pass
    if input_features is None and encoder_outputs is None:
        cross_attentions = []
        for step in cache_position:
            # skip step 0 if language_detection has been processed
            if step == 0 and self.is_language_detected:
                # Reuse the detection pass's cross-attentions. Clearing the flag
                # makes a later step 0 run normally.
                cross_attentions.append(self.language_cross)
                self.is_language_detected = False
            else:
                self.decoder_attention_mask[:, step] = 1
                decoder_output = self.decoder(
                    decoder_input_ids=input_ids[:, step : step + 1].contiguous(),
                    decoder_attention_mask=self.decoder_attention_mask,
                    cache_position=torch.full((self.batch_size, 1), step, dtype=torch.int32),
                )
                cross_attentions.append(decoder_output.cross_attentions)
                lm_logits = decoder_output.logits

        # NOTE(review): if every step in cache_position was the skipped step 0,
        # lm_logits is never assigned and the return below raises
        # UnboundLocalError. Confirm generate() never issues such a call.
        if self.rbln_token_timestamps:
            cross_attentions = torch.cat(cross_attentions, dim=-2)
        else:
            cross_attentions = None

        return Seq2SeqLMOutput(logits=lm_logits, cross_attentions=cross_attentions)

    # detect language pass
    # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/whisper/generation_whisper.py#L1442
    else:
        # for language auto detection (generate with language=None)
        if encoder_outputs is None:
            # One encoder call per row with block_tables=[b]. The output is
            # unused here; presumably the runtime fills the cross-KV cache as a
            # side effect (TODO confirm).
            for b in range(input_features.shape[0]):
                block_tables = torch.tensor([b], dtype=torch.int16)
                self.encoder(input_features=input_features[b].unsqueeze(0), block_tables=block_tables)

        self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
        self.is_language_detected = True
        self.decoder_attention_mask[:, 0] = 1
        decoder_output = self.decoder(
            decoder_input_ids=decoder_input_ids.contiguous(),
            decoder_attention_mask=self.decoder_attention_mask,
            cache_position=torch.zeros([self.rbln_config.batch_size, 1], dtype=torch.int32),
        )
        lm_logits = decoder_output.logits
        # Kept so the next default pass can skip step 0.
        self.language_cross = decoder_output.cross_attentions
        return Seq2SeqLMOutput(logits=lm_logits)
|