optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -13,13 +13,21 @@
 # limitations under the License.
 
 
+from typing import TYPE_CHECKING, Optional, Union
+
 import torch
-from transformers import
+from transformers import AutoModelForCTC, Wav2Vec2Config, Wav2Vec2ForCTC
+from transformers.modeling_outputs import CausalLMOutput
 
-from
+from ....configuration_utils import RBLNCompileConfig
+from ....modeling import RBLNModel
 from .configuration_wav2vec2 import RBLNWav2Vec2ForCTCConfig
 
 
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+
 class _Wav2Vec2(torch.nn.Module):
     def __init__(self, model: "Wav2Vec2ForCTC"):
         super().__init__()
@@ -30,22 +38,67 @@ class _Wav2Vec2(torch.nn.Module):
         return self.model.lm_head(output[0])
 
 
-class RBLNWav2Vec2ForCTC(
+class RBLNWav2Vec2ForCTC(RBLNModel):
     """
     Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
 
-    This model inherits from [`RBLNModelForMaskedLM`]. Check the superclass documentation for the generic methods the
-    library implements for all its model.
-
     It implements the methods to convert a pre-trained Wav2Vec2 model into a RBLN Wav2Vec2 model by:
+
     - transferring the checkpoint weights of the original into an optimized RBLN graph,
     - compiling the resulting graph using the RBLN compiler.
     """
 
     main_input_name = "input_values"
-    auto_model_class =
+    auto_model_class = AutoModelForCTC
     rbln_dtype = "float32"
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNWav2Vec2ForCTCConfig) -> torch.nn.Module:
         return _Wav2Vec2(model).eval()
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: "Wav2Vec2Config" = None,
+        rbln_config: Optional[RBLNWav2Vec2ForCTCConfig] = None,
+    ) -> RBLNWav2Vec2ForCTCConfig:
+        if rbln_config.max_seq_len is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_max_length"):
+                    rbln_config.max_seq_len = tokenizer.model_max_length
+                    break
+            if rbln_config.max_seq_len is None:
+                raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        rbln_compile_config = RBLNCompileConfig(
+            input_info=[
+                (
+                    "input_values",
+                    [
+                        rbln_config.batch_size,
+                        rbln_config.max_seq_len,
+                    ],
+                    "float32",
+                )
+            ]
+        )
+
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+        return rbln_config
+
+    def forward(
+        self, input_values: torch.Tensor, return_dict: Optional[bool] = None, **kwargs
+    ) -> Union[CausalLMOutput, tuple]:
+        """
+        Forward pass for the RBLN-optimized Wav2Vec2 model for Connectionist Temporal Classification (CTC).
+
+        Args:
+            input_values (torch.FloatTensor of shape (batch_size, sequence_length)): Float values of input raw speech waveform. Values can be obtained by loading a .flac or .wav audio file into an array of type List[float] or a numpy.ndarray, e.g. via the soundfile library (pip install soundfile). To prepare the array into input_values, the AutoProcessor should be used for padding and conversion into a tensor of type torch.FloatTensor.
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CausalLMOutput object.
+        """
+        return super().forward(input_values=input_values, return_dict=return_dict, **kwargs)
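The rewritten `RBLNWav2Vec2ForCTC` derives its static compile shape (`batch_size` × `max_seq_len`) from the tokenizer's `model_max_length` when `rbln_max_seq_len` is not given, per `_update_rbln_config` above. A minimal usage sketch; the checkpoint name and the `export`/`rbln_*` keyword arguments follow the usual optimum-rbln `from_pretrained` convention and are illustrative assumptions, not taken from this diff:

```python
import torch
from transformers import AutoProcessor
from optimum.rbln import RBLNWav2Vec2ForCTC

# Compile a CTC checkpoint for the RBLN NPU. rbln_max_seq_len fixes the
# static "input_values" length used in _update_rbln_config's input_info.
model = RBLNWav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",  # hypothetical checkpoint
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=160_000,       # 10 s of 16 kHz audio
)
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")

waveform = torch.zeros(160_000)     # stand-in for a real recording
inputs = processor(
    waveform, sampling_rate=16_000, return_tensors="pt",
    padding="max_length", max_length=160_000,  # pad to the compiled shape
)
logits = model(inputs.input_values).logits
print(processor.batch_decode(logits.argmax(dim=-1)))
```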
optimum/rbln/transformers/models/whisper/configuration_whisper.py
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any
-
-import rebel
+from typing import Any
 
 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger
@@ -38,17 +36,22 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
         use_attention_mask: bool = None,
         enc_max_seq_len: int = None,
         dec_max_seq_len: int = None,
-
+        kvcache_num_blocks: int = None,
+        kvcache_block_size: int = None,
+        **kwargs: Any,
     ):
         """
         Args:
             batch_size (int, optional): The batch size for inference. Defaults to 1.
             token_timestamps (bool, optional): Whether to output token timestamps during generation. Defaults to False.
             use_attention_mask (bool, optional): Whether to use attention masks during inference. This is automatically
-                set to True for RBLN-CA02 devices.
             enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
             dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
-
+            kvcache_num_blocks (int, optional): The total number of blocks to allocate for the
+                PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
+            kvcache_block_size (int, optional): Sets the size (in number of tokens) of each block
+                in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
             ValueError: If batch_size is not a positive integer.
@@ -64,10 +67,6 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
         self.dec_max_seq_len = dec_max_seq_len
 
         self.use_attention_mask = use_attention_mask
-
-
-
-            logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
-            self.use_attention_mask = True
-        else:
-            self.use_attention_mask = self.use_attention_mask or False
+        self.use_attention_mask = self.use_attention_mask or False
+        self.kvcache_num_blocks = kvcache_num_blocks
+        self.kvcache_block_size = kvcache_block_size
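The new `kvcache_num_blocks` / `kvcache_block_size` fields expose the decoder's PagedAttention layout. Given the constraints enforced in `_update_paged_attention_config` further down (one block per batch element, block size equal to `dec_max_seq_len`), a config sketch; it assumes the config class is exported from `optimum.rbln` like the other `RBLN*Config` classes, and the values are illustrative:

```python
from optimum.rbln import RBLNWhisperForConditionalGenerationConfig

# Explicit values must mirror the defaults: num_blocks == batch_size and
# block_size == dec_max_seq_len, since flash attention is not supported yet.
rbln_config = RBLNWhisperForConditionalGenerationConfig(
    batch_size=2,
    dec_max_seq_len=448,
    kvcache_num_blocks=2,    # must equal batch_size
    kvcache_block_size=448,  # must equal dec_max_seq_len
)
# Typically passed through from_pretrained(..., rbln_config=rbln_config).
```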
optimum/rbln/transformers/models/whisper/generation_whisper.py
@@ -31,22 +31,73 @@ Generation utilities for Whisper.
 Modified from `transformers.models.whisper.generation_whisper.py`
 """
 
+from typing import Any, Dict, Optional, Union
+
 import torch
 import transformers
 from packaging import version
 from transformers import GenerationMixin
+from transformers.generation.configuration_utils import GenerationConfig
+from transformers.modeling_outputs import ModelOutput
 from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
 
 
 class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
-
-
-
-
-
+    def generate(
+        self,
+        input_features: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        return_segments: Optional[bool] = None,
+        return_timestamps: Optional[bool] = None,
+        return_token_timestamps: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[ModelOutput, Dict[str, Any], torch.LongTensor]:
+        """
+        The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate) for more details.
+
+        Args:
+            input_features(torch.Tensor, optional): The input features to the model.
+            attention_mask(torch.Tensor, optional): Attention mask needs to be passed when doing long-form transcription using a batch size > 1.
+            generation_config(GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+                If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+                Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
+            return_segments(bool, optional): Whether to return segments.
+            return_timestamps(bool, optional): Whether to return the timestamps with the text. For audios longer than 30 seconds, it is necessary to set return_timestamps=True.
+            return_token_timestamps(bool, optional): Whether to return token timestamps.
+            kwargs(dict[str, Any], optional): Additional arguments passed to the generate function.
+
+        Returns:
+            Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
+        """
+        if kwargs.get("num_beams", None) is not None:
+            if kwargs.get("num_beams") != 1:
+                raise ValueError(
+                    "Beam search is not supported in RBLNWhisperGenerationMixin. "
+                    "Received num_beams={num_beams}, but only num_beams=1 is allowed. "
+                    "Please set num_beams=1 for greedy search or adjust your configuration."
+                )
+
+        return super().generate(
+            input_features,
+            attention_mask=attention_mask,
+            generation_config=generation_config,
+            return_segments=return_segments,
+            return_timestamps=return_timestamps,
+            return_token_timestamps=return_token_timestamps,
+            **kwargs,
+        )
 
     def _postprocess_outputs(
-        self,
+        self,
+        seek_outputs,
+        decoder_input_ids,
+        return_token_timestamps,
+        generation_config,
+        is_shortform,
+        seek,
+        batch_idx_map,
     ):
         # remove all previously passed decoder input ids
         # should happen only if it is the first generated segment
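The overridden `generate` keeps the upstream HuggingFace signature but rejects beam search. A call sketch; model and processor construction are elided, and only keyword arguments named in the diff's docstring are used:

```python
# `model` is an RBLNWhisperForConditionalGeneration and `processor` a
# WhisperProcessor (both assumed constructed elsewhere).
generated_ids = model.generate(
    input_features,          # log-mel features from the processor
    return_timestamps=True,  # required for audio longer than 30 seconds
    num_beams=1,             # anything else raises ValueError per the override
)
text = processor.batch_decode(generated_ids, skip_special_tokens=True)
```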
@@ -64,6 +115,11 @@ class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
 
         if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
             num_frames = getattr(generation_config, "num_frames", None)
+
+            if num_frames is not None:
+                num_frames = num_frames - seek
+                num_frames = num_frames[batch_idx_map]
+
             if version.parse(transformers.__version__) >= version.parse("4.46.0"):
                 seek_outputs["token_timestamps"] = self._extract_token_timestamps(
                     seek_outputs,
optimum/rbln/transformers/models/whisper/modeling_whisper.py
@@ -46,7 +46,7 @@ if TYPE_CHECKING:
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
-    def forward(self, *args: List[torch.Tensor], **kwargs:
+    def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
         output = super().forward(*args, **kwargs)
         return BaseModelOutput(last_hidden_state=output)
 
@@ -73,6 +73,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         decoder_input_ids: torch.Tensor = None,
         decoder_attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
+        block_tables: torch.Tensor = None,
     ):
         inputs_bsz = decoder_input_ids.shape[0]
         padded_bsz = self.batch_size - inputs_bsz
@@ -89,11 +90,14 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
             )
             decoder_attention_mask[b_idx, : decoding_step + 1] = 1
 
+        if block_tables is None:
+            block_tables = self.default_block_tables
+
         outputs = super().forward(
             decoder_input_ids,
             decoder_attention_mask if self.use_attention_mask else None,
             cache_position,
-            block_tables=
+            block_tables=block_tables,
         )
 
         if isinstance(outputs, torch.Tensor):
@@ -108,6 +112,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
 
     This model inherits from [`RBLNModel`]. It implements the methods to convert and run
     pre-trained transformers based Whisper model on RBLN devices by:
+
     - transferring the checkpoint weights of the original into an optimized RBLN graph,
     - compiling the resulting graph using the RBLN compiler.
 
@@ -145,7 +150,8 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     """
 
     auto_model_class = AutoModelForSpeechSeq2Seq
-    main_input_name = "
+    main_input_name = "input_features"
+    _is_stateful = False
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
@@ -197,7 +203,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         raise NotImplementedError
 
     @classmethod
-    def
+    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNWhisperForConditionalGenerationConfig):
         return WhisperWrapper(
             model,
             use_attention_mask=rbln_config.use_attention_mask,
@@ -207,7 +213,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model, rbln_config: RBLNWhisperForConditionalGenerationConfig):
-        wrapped_model = cls.
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
         dec_compile_config = rbln_config.compile_cfgs[1]
@@ -249,6 +255,23 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
 
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
+    @classmethod
+    def _update_paged_attention_config(
+        cls, model_config: "PretrainedConfig", rbln_config: RBLNWhisperForConditionalGenerationConfig
+    ):
+        rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
+        rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
+
+        if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
+            raise NotImplementedError(
+                f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
+            )
+
+        if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
+            raise NotImplementedError(
+                f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
+            )
+
     @classmethod
     def _update_rbln_config(
         cls,
@@ -266,6 +289,8 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         if rbln_config.dec_max_seq_len is None:
             rbln_config.dec_max_seq_len = model_config.max_length
 
+        cls._update_paged_attention_config(model_config, rbln_config)
+
         enc_input_info = [
             ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
             ("block_tables", [1], "int16"),
@@ -345,12 +370,14 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
                 tensor_type="pt",
                 device=rbln_config.device_map["encoder"],
                 activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
            ),
            rebel.Runtime(
                compiled_models[1],
                tensor_type="pt",
                device=rbln_config.device_map["decoder"],
                activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
            ),
        ]
 
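Taken together, the Whisper modeling changes thread `block_tables` and the runtime `timeout` through decoder execution without changing the user-facing API. A hedged load-and-compile sketch; the checkpoint name and `export`/`rbln_*` keyword arguments follow the usual optimum-rbln convention and are illustrative:

```python
from optimum.rbln import RBLNWhisperForConditionalGeneration

# Compile once with export=True; block_tables defaults and the runtime
# timeout are handled internally (default_block_tables, rbln_config.timeout)
# per the hunks above.
model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",  # hypothetical checkpoint
    export=True,
    rbln_batch_size=1,
)
model.save_pretrained("whisper-tiny-rbln")  # reload later without re-export
```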
optimum/rbln/transformers/models/xlm_roberta/__init__.py
@@ -12,14 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .configuration_xlm_roberta import
-
-    RBLNXLMRobertaModelConfig,
-)
-from .modeling_xlm_roberta import (
-    RBLNXLMRobertaForSequenceClassification,
-    RBLNXLMRobertaModel,
-)
+from .configuration_xlm_roberta import RBLNXLMRobertaForSequenceClassificationConfig, RBLNXLMRobertaModelConfig
+from .modeling_xlm_roberta import RBLNXLMRobertaForSequenceClassification, RBLNXLMRobertaModel
 
 
 __all__ = [
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Union
+
+import torch
+from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, SequenceClassifierOutput
+
 from ...modeling_generic import RBLNModelForSequenceClassification, RBLNTransformerEncoderForFeatureExtraction
 
 
@@ -20,6 +25,25 @@ class RBLNXLMRobertaModel(RBLNTransformerEncoderForFeatureExtraction):
     XLM-RoBERTa base model optimized for RBLN NPU.
     """
 
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, tuple]:
+        """
+        Forward pass for the RBLN-optimized XLM-RoBERTa base model.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPoolingAndCrossAttentions object.
+        """
+
+        return super().forward(input_ids, attention_mask, **kwargs)
+
 
 class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
     """
@@ -27,3 +51,22 @@ class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification
     """
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ) -> Union[SequenceClassifierOutput, tuple]:
+        """
+        Forward pass for the RBLN-optimized XLM-RoBERTa model for sequence classification.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SequenceClassifierOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, **kwargs)