optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +173 -35
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +816 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +111 -137
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +56 -71
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
- optimum/rbln/modeling.py +66 -40
- optimum/rbln/modeling_base.py +111 -86
- optimum/rbln/ops/__init__.py +4 -7
- optimum/rbln/ops/attn.py +271 -205
- optimum/rbln/ops/flash_attn.py +161 -67
- optimum/rbln/ops/kv_cache_update.py +4 -40
- optimum/rbln/ops/linear.py +25 -0
- optimum/rbln/transformers/__init__.py +97 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +120 -32
- optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
- optimum/rbln/transformers/models/bart/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
- optimum/rbln/transformers/models/t5/__init__.py +2 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
- optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +2 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/hub.py +2 -2
- optimum/rbln/utils/import_utils.py +23 -6
- optimum/rbln/utils/model_utils.py +4 -4
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +36 -44
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
- optimum_rbln-0.7.4.dist-info/RECORD +169 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
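The headline change in 0.7.4 is the replacement of `modeling_config.py` (`RBLNConfig` plus loose `rbln_kwargs` dictionaries) with the new `configuration_utils.py` and per-model configuration classes such as `RBLNWhisperForConditionalGenerationConfig`. The Whisper diff below shows the pattern: `_get_rbln_config` becomes `_update_rbln_config`, and ad-hoc `model_cfg` keys become typed attributes (`batch_size`, `dec_max_seq_len`, `token_timestamps`, `use_attention_mask`). A rough sketch of the resulting user-facing API follows; the constructor keywords mirror the attributes read in the diff, but the exact `from_pretrained` signature is an assumption, not verified against the released wheel.

```python
# Hypothetical usage sketch based on names visible in this diff; the exact
# constructor and from_pretrained signatures are assumptions, not verified API.
from optimum.rbln import RBLNWhisperForConditionalGeneration
from optimum.rbln.transformers.models.whisper.configuration_whisper import (
    RBLNWhisperForConditionalGenerationConfig,
)

rbln_config = RBLNWhisperForConditionalGenerationConfig(
    batch_size=1,             # read back in __post_init__ as rbln_config.batch_size
    token_timestamps=False,   # replaces rbln_kwargs["token_timestamps"]
    use_attention_mask=None,  # new in 0.7.4: masked vs. causal paged decode
)

model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-small",
    export=True,              # compile from the Hugging Face checkpoint
    rbln_config=rbln_config,
)
```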
optimum/rbln/transformers/models/whisper/modeling_whisper.py

```diff
@@ -18,19 +18,14 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import (
-    AutoModelForSpeechSeq2Seq,
-    AutoProcessor,
-    PretrainedConfig,
-    WhisperForConditionalGeneration,
-    WhisperModel,
-)
+from transformers import AutoModelForSpeechSeq2Seq, WhisperForConditionalGeneration, WhisperModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from .configuration_whisper import RBLNWhisperForConditionalGenerationConfig
 from .generation_whisper import RBLNWhisperGenerationMixin
 from .whisper_architecture import WhisperWrapper
 
@@ -38,29 +33,41 @@ from .whisper_architecture import WhisperWrapper
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, PretrainedConfig, PreTrainedModel
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        GenerationConfig,
+        PretrainedConfig,
+        PreTrainedModel,
+    )
 
 
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
-    def forward(self,
-
-
-
-        n_pad_to_batch = self.batch_size - input_features.shape[0]
-        if n_pad_to_batch > 0:
-            input_features = torch.nn.functional.pad(input_features, (0, 0, 0, 0, 0, n_pad_to_batch))
-
-        _ = super().forward(input_features=input_features)
-
-        # dummy output for generation
-        return BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
+    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+        output = super().forward(*args, **kwargs)
+        return BaseModelOutput(last_hidden_state=output)
 
 
 class RBLNRuntimeDecoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
+    def __init__(
+        self,
+        runtime: rebel.Runtime,
+        batch_size: int,
+        dec_max_seq_len: int,
+        use_attention_mask: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(runtime, **kwargs)
+        self.batch_size = batch_size
+        self.dec_max_seq_len = dec_max_seq_len
+        self.use_attention_mask = use_attention_mask
+        self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
+
     def forward(
         self,
         decoder_input_ids: torch.Tensor = None,
@@ -69,13 +76,24 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
     ):
         inputs_bsz = decoder_input_ids.shape[0]
         padded_bsz = self.batch_size - inputs_bsz
+
         if padded_bsz > 0:
             decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, 0, 0, padded_bsz))
 
+        if self.use_attention_mask:
+            for b_idx in range(self.batch_size):
+                decoding_step = cache_position[b_idx].item()
+                if not (0 <= decoding_step < self.dec_max_seq_len):
+                    raise ValueError(
+                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    )
+                decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
         outputs = super().forward(
-            decoder_input_ids,
-            decoder_attention_mask,
-            cache_position,
+            decoder_input_ids,
+            decoder_attention_mask if self.use_attention_mask else None,
+            cache_position,
+            block_tables=self.default_block_tables,
         )
 
         if isinstance(outputs, torch.Tensor):
@@ -101,15 +119,18 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
 
-        self.batch_size = self.rbln_config.model_cfg["batch_size"]
-        self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-        self.rbln_token_timestamps = self.rbln_config.model_cfg["token_timestamps"]
+        self.batch_size = self.rbln_config.batch_size
+        self.dec_max_seq_len = self.rbln_config.dec_max_seq_len
+        self.rbln_token_timestamps = self.rbln_config.token_timestamps
+        self.use_attention_mask = self.rbln_config.use_attention_mask
 
-        self.encoder = RBLNRuntimeEncoder(
-            runtime=self.model[0], main_input_name="input_features", batch_size=self.batch_size
-        )
+        self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
         self.decoder = RBLNRuntimeDecoder(
-            runtime=self.model[1], main_input_name="input_ids", batch_size=self.batch_size
+            runtime=self.model[1],
+            main_input_name="input_ids",
+            batch_size=self.batch_size,
+            dec_max_seq_len=self.dec_max_seq_len,
+            use_attention_mask=self.use_attention_mask,
         )
 
         # skip encoder & first decoder when language detected
@@ -150,13 +171,16 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         raise NotImplementedError
 
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNConfig):
-
-
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNWhisperForConditionalGenerationConfig):
+        return WhisperWrapper(
+            model,
+            use_attention_mask=rbln_config.use_attention_mask,
+            rbln_token_timestamps=rbln_config.token_timestamps,
+        )
 
     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+    def get_compiled_model(cls, model, rbln_config: RBLNWhisperForConditionalGenerationConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
@@ -196,47 +220,42 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
-
-
-
-
-        rbln_token_timestamps = rbln_kwargs.get("token_timestamps", False)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNWhisperForConditionalGenerationConfig] = None,
+    ) -> RBLNWhisperForConditionalGenerationConfig:
         expected_seq_len = model_config.max_source_positions * 2
         num_mel_bins = model_config.num_mel_bins
-        enc_max_seq_len = model_config.max_source_positions
+        rbln_config.enc_max_seq_len = model_config.max_source_positions
 
         # 'whisper-large-v3-turbo' doesn't have 'max_length', but PretrainedConfig have default value for the key 'max_length'
-
-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        ]
-        )
+        rbln_config.dec_max_seq_len = getattr(model_config, "max_target_positions", None)
+        if rbln_config.dec_max_seq_len is None:
+            rbln_config.dec_max_seq_len = model_config.max_length
+
+        enc_input_info = [
+            ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
+            ("block_tables", [1], "int16"),
+            (
+                "cross_key_value_states",
+                [
+                    model_config.decoder_layers * 2,
+                    rbln_config.batch_size,
+                    model_config.decoder_attention_heads,
+                    rbln_config.enc_max_seq_len,
+                    model_config.d_model // model_config.decoder_attention_heads,
+                ],
+                "float32",
+            ),
+        ]
 
         dec_input_info = [
-            ("decoder_input_ids", [
-            ("
-            ("
+            ("decoder_input_ids", [rbln_config.batch_size, 1], "int64"),
+            ("cache_position", [rbln_config.batch_size, 1], "int32"),
+            ("block_tables", [rbln_config.batch_size, 1], "int16"),
         ]
         dec_input_info.extend(
             [
@@ -244,9 +263,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
                 "cross_key_value_states",
                 [
                     model_config.decoder_layers * 2,
-                    rbln_batch_size,
+                    rbln_config.batch_size,
                     model_config.decoder_attention_heads,
-                    enc_max_seq_len,
+                    rbln_config.enc_max_seq_len,
                     model_config.d_model // model_config.decoder_attention_heads,
                 ],
                 "float32",
@@ -258,9 +277,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
             (
                 f"self_key_value_states_{i}",
                 [
-                    rbln_batch_size,
+                    rbln_config.batch_size,
                     model_config.decoder_attention_heads,
-                    rbln_dec_max_seq_len,
+                    rbln_config.dec_max_seq_len,
                     model_config.d_model // model_config.encoder_attention_heads,
                 ],
                 "float32",
@@ -269,22 +288,15 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
             ]
         )
 
+        if rbln_config.use_attention_mask:
+            dec_input_info.insert(
+                1, ("decoder_attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
+            )
+
         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[enc_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "batch_size": rbln_batch_size,
-                "dec_max_seq_len": rbln_dec_max_seq_len,
-                "token_timestamps": rbln_token_timestamps,
-            }
-        )
+        rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
 
         return rbln_config
 
```
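Both compile configs above are built from `input_info` lists of `(name, shape, dtype)` triples and attached with `set_compile_cfgs`. To make the static shapes concrete, here is the decoder input list expanded with whisper-small-like dimensions (12 decoder layers, 12 heads, d_model 768, max_target_positions 448, max_source_positions 1500); the assumption that the self-KV loop runs over `decoder_layers * 2` entries (one key and one value tensor per layer) is mine, inferred from the `f"self_key_value_states_{i}"` naming:

```python
# Pure-Python expansion of the (name, shape, dtype) triples built above,
# using assumed whisper-small-like dimensions. Illustrative only.
batch_size, dec_max_seq_len, enc_max_seq_len = 1, 448, 1500
decoder_layers, heads, d_model = 12, 12, 768
head_dim = d_model // heads  # 64

dec_input_info = [
    ("decoder_input_ids", [batch_size, 1], "int64"),
    ("cache_position", [batch_size, 1], "int32"),
    ("block_tables", [batch_size, 1], "int16"),
    ("cross_key_value_states", [decoder_layers * 2, batch_size, heads, enc_max_seq_len, head_dim], "float32"),
]
dec_input_info += [
    (f"self_key_value_states_{i}", [batch_size, heads, dec_max_seq_len, head_dim], "float32")
    for i in range(decoder_layers * 2)  # assumed loop bound
]
print(len(dec_input_info))  # 28 static inputs; one more (decoder_attention_mask) when use_attention_mask is set
```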
optimum/rbln/transformers/models/whisper/modeling_whisper.py (continued)

```diff
@@ -292,18 +304,23 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNWhisperForConditionalGenerationConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in
+        if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
             cls._raise_missing_compiled_file_error(["encoder", "decoder"])
 
         return [
-
-
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["encoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
-
-
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
         ]
 
@@ -327,11 +344,25 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
 
     # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L512
     def _prepare_encoder_decoder_kwargs_for_generation(
-        self,
+        self,
+        inputs_tensor: torch.Tensor,
+        model_kwargs,
+        model_input_name: Optional[str] = None,
+        generation_config: Optional["GenerationConfig"] = None,
+        **kwargs,
     ) -> Dict[str, Any]:
+        batch_size = inputs_tensor.shape[0]
+        n_pad_to_batch = self.batch_size - batch_size
+        if n_pad_to_batch > 0:
+            inputs_tensor = torch.nn.functional.pad(inputs_tensor, (0, 0, 0, 0, 0, n_pad_to_batch))
+
         if not self.is_language_detected:
-
-
+            for b in range(inputs_tensor.shape[0]):
+                block_tables = torch.tensor([b], dtype=torch.int16)
+                model_kwargs["encoder_outputs"] = self.encoder(
+                    input_features=inputs_tensor[b].unsqueeze(0), block_tables=block_tables
+                )
+            self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
         else:
             model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
 
@@ -359,7 +390,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
             decoder_output = self.decoder(
                 decoder_input_ids=input_ids[:, step : step + 1].contiguous(),
                 decoder_attention_mask=self.decoder_attention_mask,
-                cache_position=
+                cache_position=torch.full((self.batch_size, 1), step, dtype=torch.int32),
             )
             cross_attentions.append(decoder_output.cross_attentions)
             lm_logits = decoder_output.logits
@@ -374,15 +405,19 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         # detect language pass
         # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/whisper/generation_whisper.py#L1442
         else:
+            # for language auto detection (generate with language=None)
            if encoder_outputs is None:
-
-
+                for b in range(input_features.shape[0]):
+                    block_tables = torch.tensor([b], dtype=torch.int16)
+                    self.encoder(input_features=input_features[b].unsqueeze(0), block_tables=block_tables)
+
+                self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
             self.is_language_detected = True
             self.decoder_attention_mask[:, 0] = 1
             decoder_output = self.decoder(
                 decoder_input_ids=decoder_input_ids.contiguous(),
                 decoder_attention_mask=self.decoder_attention_mask,
-                cache_position=torch.zeros([], dtype=torch.int32),
+                cache_position=torch.zeros([self.rbln_config.batch_size, 1], dtype=torch.int32),
             )
             lm_logits = decoder_output.logits
             self.language_cross = decoder_output.cross_attentions
```
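The other theme in this file is paged (block-table) KV-cache addressing: `RBLNRuntimeDecoder` precomputes `default_block_tables`, and both encoder and decoder calls now thread a `block_tables` tensor through to the compiled model. A minimal plain-PyTorch sketch of what that default table contains, mirroring the `torch.arange(...).view(...)` line above:

```python
# Plain-PyTorch illustration of the default block table built by
# RBLNRuntimeDecoder above: one KV-cache block per batch slot.
import torch

batch_size = 4
default_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, 1)
print(default_block_tables)
# tensor([[0],
#         [1],
#         [2],
#         [3]], dtype=torch.int16)
```

The per-sample encoder calls in `_prepare_encoder_decoder_kwargs_for_generation` follow the same convention: sample `b` passes `block_tables=torch.tensor([b])`, so its cross-attention KV lands in block `b`.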
optimum/rbln/transformers/models/whisper/whisper_architecture.py

```diff
@@ -16,27 +16,19 @@ from typing import Optional, Tuple, Union
 
 import torch
 from torch import nn
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_causal_attention_mask,
-)
-from transformers.modeling_outputs import (
-    BaseModelOutput,
-    Seq2SeqLMOutput,
-)
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 from transformers.utils import logging
 
-from ....ops import register_rbln_custom_add_softmax_attention, register_rbln_custom_cache_update
-
 
 logger = logging.get_logger(__name__)
 
 
 class WhisperWrapper:
-    def __init__(self, model, rbln_token_timestamps):
-        register_rbln_custom_cache_update()
-        register_rbln_custom_add_softmax_attention()
+    def __init__(self, model, use_attention_mask, rbln_token_timestamps):
         self.encoder = WhisperEncoderWrapper(model)
-        self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)
+        self.decoder = WhisperDecoderWrapper(
+            model, use_attention_mask=use_attention_mask, output_attentions=rbln_token_timestamps
+        )
 
 
 class WhisperEncoderWrapper(torch.nn.Module):
@@ -57,6 +49,7 @@ class WhisperEncoderWrapper(torch.nn.Module):
     def forward(
         self,
         input_features: Optional[torch.LongTensor],
+        b_idx: torch.Tensor,
         cross_key_values: torch.Tensor,
     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
         # 1. get encoder last_hidden_states
@@ -76,21 +69,31 @@ class WhisperEncoderWrapper(torch.nn.Module):
         cross_kv = torch.stack(cross_kv, dim=0)
 
         # 3. update cross_attention's past_key_value to the device-dram for optimization.
-
-
-
+        batch_axis = torch.tensor(1, dtype=torch.int16)
+        cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(
+            cross_key_values, cross_kv, b_idx[0], batch_axis
+        )
 
-        return
+        return cross_key_values
 
 
 class WhisperDecoderWrapper(torch.nn.Module):
-    def __init__(self, model, output_attentions: bool = False):
+    def __init__(self, model, use_attention_mask: bool = True, output_attentions: bool = False, **kwargs):
         super().__init__()
         self.config = model.config
-        self.num_layers = self.config.decoder_layers
         self.proj_out = model.proj_out
-        self.decoder = self.convert_to_rbln_conditional_generation(model)
+        self.use_attention_mask = use_attention_mask
         self.output_attentions = output_attentions
+        self.__post_init__(model, **kwargs)
+
+    def __post_init__(self, model: nn.Module, **kwargs):
+        """
+        Post-initialization to extract and configure encoder-related attributes.
+        It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
+        by subclasses to modify or add custom attributes as necessary.
+        """
+        self.num_layers = self.config.decoder_layers
+        self.decoder = self.convert_to_rbln_conditional_generation(model)
 
     def convert_to_rbln_conditional_generation(self, model: nn.Module):
         new_layers = []
@@ -105,12 +108,21 @@ class WhisperDecoderWrapper(torch.nn.Module):
 
     def forward(
         self,
-        decoder_input_ids: torch.Tensor,
-        decoder_attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-        *self_kv_cache: torch.Tensor,
+        *args,
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        if self.use_attention_mask:
+            (
+                decoder_input_ids,
+                decoder_attention_mask,
+                cache_position,
+                block_tables,
+                cross_kv_cache,
+                *self_kv_cache,
+            ) = args
+        else:
+            decoder_attention_mask = None
+            (decoder_input_ids, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
+
         # prepare past_key_values
         self_past_key_values = ()
         cross_past_key_values = ()
@@ -125,6 +137,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
             cache_position=cache_position,
             self_past_key_values=self_past_key_values,
             cross_past_key_values=cross_past_key_values,
+            block_tables=block_tables,
         )
 
         lm_logits = self.proj_out(sequence_output)
@@ -154,17 +167,25 @@ class WhisperDecoder(nn.Module):
         self_past_key_values: Optional[torch.Tensor] = None,
         cross_past_key_values: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])
 
         # positional embeding
         inputs_embeds = self.embed_tokens(input_ids)
-
-
+        all_hiddens = []
+        for i in range(inputs_embeds.shape[0]):
+            position_id = cache_position[i]
+            position = self.embed_positions.weight[position_id]
+            batch_hidden = position + inputs_embeds[i]
+            all_hiddens.append(batch_hidden)
+
+        hidden_states = torch.cat(all_hiddens, dim=0).unsqueeze(1)
 
-        # prepare
-
+        # prepare attn mask (normal attention - masked)
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :]
 
         cross_attentions = ()
         # iterate decoder_layer
@@ -177,6 +198,7 @@
                 self_past_key_value=self_past_key_value,
                 cross_past_key_value=cross_past_key_value,
                 cache_position=cache_position,
+                block_tables=block_tables,
             )
             cross_attentions += (cross_attn_weights,)
 
@@ -205,6 +227,7 @@ class WhisperDecoderLayer(nn.Module):
         self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
         cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
         cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Self Attention Block
         residual = hidden_states
@@ -214,6 +237,7 @@
             past_key_value=self_past_key_value,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            block_tables=block_tables,
         )
         hidden_states = residual + hidden_states
 
@@ -263,6 +287,7 @@ class WhisperSelfAttention(WhisperAttention):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, tgt_len, _ = hidden_states.size()
         query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
@@ -270,17 +295,25 @@ class WhisperSelfAttention(WhisperAttention):
 
         key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
         value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-
-
-
-
-
-            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
-            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
-            cache_position.expand(bsz, 1),
-            torch.tensor(1.0, dtype=torch.float32),
-
+        block_size = past_key_value[0].shape[-2]
+
+        args = {
+            "q": query_states,
+            "k": key_states,
+            "v": value_states,
+            "kcache": past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            "vcache": past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            "seq": cache_position.expand(bsz, 1),
+            "scale": torch.tensor(1.0, dtype=torch.float32),
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if attention_mask is not None:
+            args["mask"] = attention_mask.unsqueeze(2)
+            attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(**args)
+        else:
+            attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(**args)
 
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
```
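The attention rewrite drops the old `register_rbln_custom_add_softmax_attention`/`register_rbln_custom_cache_update` setup in favor of paged kernels, dispatching on whether a decoder attention mask is present: with a mask, `paged_attn_decode`; without one, `paged_causal_attn_decode`. The mask plumbing is easy to lose in the diff, so here is the shape walk-through in plain PyTorch (shapes inferred from the code above; the custom ops themselves only exist on RBLN devices):

```python
# Shape walk-through for the mask path above, without the custom ops.
import torch

batch, dec_max_seq_len = 2, 8
mask = torch.zeros(batch, dec_max_seq_len)  # built in modeling_whisper.py
mask[:, :3] = 1                             # three valid decoding steps

mask_4d = mask[:, None, None, :]            # WhisperDecoder: [batch, 1, 1, seq]
kernel_mask = mask_4d.unsqueeze(2)          # WhisperSelfAttention: [batch, 1, 1, 1, seq]
print(mask_4d.shape, kernel_mask.shape)
# torch.Size([2, 1, 1, 8]) torch.Size([2, 1, 1, 1, 8])
```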
optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py (new file)

```diff
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+
+
+class RBLNXLMRobertaModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+    pass
```