optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
- optimum/rbln/__init__.py +37 -2
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
- optimum/rbln/diffusers/models/controlnet.py +56 -40
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
- optimum/rbln/modeling_alias.py +3 -3
- optimum/rbln/modeling_base.py +471 -231
- optimum/rbln/modeling_config.py +152 -77
- optimum/rbln/modeling_seq2seq.py +166 -77
- optimum/rbln/transformers/__init__.py +35 -1
- optimum/rbln/transformers/models/__init__.py +20 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
- optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +36 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +19 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
- optimum_rbln-0.1.11.dist-info/RECORD +93 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
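The detailed hunks below cover three of these files: `modeling_whisper.py`, `whisper_architecture.py`, and `modeling_xlm_roberta.py`. The headline Whisper change is a new `RBLNWhisperGenerationMixin` plus an optional `token_timestamps` compile flag. A minimal, hypothetical usage sketch, assuming the usual optimum-rbln convention that `rbln_*`-prefixed keyword arguments to `from_pretrained` are collected into `rbln_kwargs` at compile time (the calling surface itself is not shown in this diff):

```python
# Hypothetical export sketch; assumes the usual optimum-rbln convention that
# rbln_*-prefixed kwargs passed to from_pretrained become rbln_kwargs entries
# ("batch_size", "token_timestamps", ...) at compile time.
from optimum.rbln import RBLNWhisperForConditionalGeneration

model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-small",
    export=True,                  # compile from the original Hugging Face checkpoint
    rbln_batch_size=1,            # -> rbln_kwargs["batch_size"]
    rbln_token_timestamps=True,   # -> rbln_kwargs["token_timestamps"] (new in 0.1.11)
)
model.save_pretrained("whisper-small-rbln")
```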
optimum/rbln/transformers/models/whisper/modeling_whisper.py:

```diff
@@ -30,7 +30,6 @@ import torch
 from transformers import (
     AutoModelForSpeechSeq2Seq,
     AutoProcessor,
-    GenerationMixin,
     PretrainedConfig,
     WhisperForConditionalGeneration,
     WhisperModel,
```
```diff
@@ -38,8 +37,9 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

 from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME,
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from .generation_whisper import RBLNWhisperGenerationMixin
 from .whisper_architecture import (
     _WhisperDecoderWrapper,
     _WhisperEncoderWrapper,
```
```diff
@@ -60,8 +60,15 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]

     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-
-
+        # backward compatibility transformers==4.40.2
+        # https://github.com/huggingface/transformers/blob/4fdf58afb72b0754da30037fc800b6044e7d9c99/src/transformers/pipelines/automatic_speech_recognition.py#L494
+        input_features = kwargs.get("input_features", None)
+        if input_features is None:
+            input_features = args[0]
+
+        _ = super().forward(input_features=input_features)
+        # dummy output for generation
+        return BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))


 class RBLNRuntimeDecoder(RBLNPytorchRuntime):
```
```diff
@@ -69,10 +76,13 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):

     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
         outputs = super().forward(*args, **kwargs)
-
+        if isinstance(outputs, torch.Tensor):
+            return Seq2SeqLMOutput(logits=outputs, cross_attentions=None)
+        else:
+            return Seq2SeqLMOutput(logits=outputs[0], cross_attentions=outputs[1])


-class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
+class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
     """
     The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
```
```diff
@@ -88,15 +98,22 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
     main_input_name = "input_ids"

     def __post_init__(self, **kwargs):
-
-
-        self.
+        super().__post_init__(**kwargs)
+
+        self.batch_size = self.rbln_config.model_cfg["batch_size"]
+        self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        self.rbln_token_timestamps = self.rbln_config.model_cfg["token_timestamps"]

         self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
         self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
-        self.forced_decoder_ids = self.config.forced_decoder_ids

-        #
+        # skip encoder & first decoder when language detected
+        self.is_language_detected = False
+        self.language_cross = None
+
+        # Used in GenerationMixin.generate()
+        # transformers/models/whisper/generation_whisper.py, line 505, in generate
+        # input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
         self.model = WhisperModel(self.config)
         self.pad_token_id = self.config.pad_token_id

```
```diff
@@ -127,63 +144,32 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
         # TODO(jongho): implement
         raise NotImplementedError

-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        decoder_attention_mask=None,
-        input_features=None, # Must be explicit
-        **kwargs,
-    ):
-        max_seq_len = self.dec_max_seq_len
-        cur_seq_len = input_ids.shape[-1]
-        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
-        decoder_attention_mask = torch.zeros(self.batch_size, max_seq_len, dtype=torch.int64)
-        decoder_attention_mask[:, :cur_seq_len] = 1
-        cache_position = torch.tensor(cur_seq_len - 1, dtype=torch.int32)
-
-        return {
-            "decoder_input_ids": input_ids,
-            "decoder_attention_mask": decoder_attention_mask,
-            "cache_position": cache_position,
-        }
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-            }
-        )
-        return kwargs
-
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+        rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
         wrapped_encoder = _WhisperEncoderWrapper(model).eval()
-        wrapped_decoder = _WhisperDecoderWrapper(model).eval()
+        wrapped_decoder = _WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps).eval()

-
-
+        enc_rbln_compile_config = rbln_config.compile_cfgs[0]
+        dec_rbln_compile_config = rbln_config.compile_cfgs[1]

-        enc_example_inputs =
-        dec_example_inputs =
+        enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=1)
+        dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=1)

-        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs
+        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
         dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)

         enc_ir = rebel.torchscript_to_ir(
             enc_scripted_model,
-            input_names=[v[0] for v in
-            name=
+            input_names=[v[0] for v in enc_rbln_compile_config.input_info],
+            name=enc_rbln_compile_config.mod_name,
         )
         dec_ir = rebel.torchscript_to_ir(
             dec_scripted_model,
-            input_names=[v[0] for v in
-            name=
+            input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+            name=dec_rbln_compile_config.mod_name,
         )
-        dec_ir.batch_size = dec_rbln_runtime_config.batch_size

         # Caching encoder/decoder I/O
         connections = [
```
```diff
@@ -194,9 +180,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
             enc_ir,
             dec_ir,
             connections=connections,
-            fusion=
-            npu=
-            tensor_parallel_size=
+            fusion=enc_rbln_compile_config.fusion,
+            npu=enc_rbln_compile_config.npu,
+            tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
         )
         return compiled_model

```
```diff
@@ -205,42 +191,22 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
         model_config: "PretrainedConfig",
-
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
-
-        input_max_length = 3000
-        rbln_enc_num_mel_bins = getattr(model_config, "num_mel_bins", None)
-        if rbln_enc_num_mel_bins is None:
-            for feature_extractor in preprocessors:
-                if hasattr(feature_extractor, "feature_size"):
-                    rbln_enc_num_mel_bins = feature_extractor.feature_size
-                    break
-            raise ValueError("`rbln_enc_num_mel_bins` should be specified!")
-
-        rbln_enc_max_seq_len = getattr(model_config, "max_source_positions", None)
-        if rbln_enc_max_seq_len is None:
-            raise ValueError("`rbln_enc_max_seq_len` should be specified!")
-
-        rbln_dec_max_seq_len = getattr(model_config, "max_length", None)
-        if rbln_dec_max_seq_len is None:
-            raise ValueError("`rbln_dec_max_seq_len` should be specified!")
-
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_token_timestamps = rbln_kwargs.get("token_timestamps", False)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-        decoder_batch_size = rbln_batch_size

-
-
-
-
-        meta["decoder_batch_size"] = decoder_batch_size
-        meta["forced_decoder_ids"] = model_config.forced_decoder_ids
+        expected_seq_len = model_config.max_source_positions * 2
+        num_mel_bins = model_config.num_mel_bins
+        enc_max_seq_len = model_config.max_source_positions
+        rbln_dec_max_seq_len = model_config.max_length

         # model input info
-        enc_input_info = [("input_features", [rbln_batch_size,
+        enc_input_info = [("input_features", [rbln_batch_size, num_mel_bins, expected_seq_len], "float32")]
         dec_input_info = [
-            ("decoder_input_ids", [
-            ("decoder_attention_mask", [
+            ("decoder_input_ids", [rbln_batch_size, 1], "int64"),
+            ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
             ("cache_position", [], "int32"),
         ]
         dec_input_info.extend(
```
```diff
@@ -249,7 +215,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
                 "self_key_value_states",
                 [
                     model_config.decoder_layers * 2,
-
+                    rbln_batch_size,
                     model_config.decoder_attention_heads,
                     rbln_dec_max_seq_len,
                     model_config.d_model // model_config.encoder_attention_heads,
```
```diff
@@ -266,7 +232,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
                     model_config.decoder_layers * 2,
                     rbln_batch_size,
                     model_config.decoder_attention_heads,
-
+                    enc_max_seq_len,
                     model_config.d_model // model_config.encoder_attention_heads,
                 ],
                 "float32",
```
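For concreteness, the static shapes declared in the `input_info` entries above can be worked out from a stock checkpoint. A small sketch, assuming the published `openai/whisper-small` configuration values (illustrative; not taken from this diff):

```python
# Worked shape example for the input_info entries above, assuming the stock
# openai/whisper-small config (illustrative values, not part of the diff).
decoder_layers = 12
decoder_attention_heads = 12
encoder_attention_heads = 12
d_model = 768
num_mel_bins = 80
max_source_positions = 1500   # enc_max_seq_len
max_length = 448              # rbln_dec_max_seq_len
batch_size = 1

expected_seq_len = max_source_positions * 2                 # 3000 mel frames (~30 s of audio)
enc_input = [batch_size, num_mel_bins, expected_seq_len]    # [1, 80, 3000]
head_dim = d_model // encoder_attention_heads               # 64
self_kv = [decoder_layers * 2, batch_size, decoder_attention_heads, max_length, head_dim]
cross_kv = [decoder_layers * 2, batch_size, decoder_attention_heads, max_source_positions, head_dim]
print(enc_input, self_kv, cross_kv)
```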
```diff
@@ -274,15 +240,21 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
             ]
         )

-
-
+        enc_rbln_compile_config = RBLNCompileConfig(mod_name="encoder", input_info=enc_input_info)
+        dec_rbln_compile_config = RBLNCompileConfig(mod_name="decoder", input_info=dec_input_info)

-
-
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[enc_rbln_compile_config, dec_rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )

-        rbln_config
-
-
+        rbln_config.model_cfg.update(
+            {
+                "batch_size": rbln_batch_size,
+                "dec_max_seq_len": rbln_dec_max_seq_len,
+                "token_timestamps": rbln_token_timestamps,
+            }
         )

         return rbln_config
```
```diff
@@ -297,18 +269,82 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
             compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]

+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        """
+        whisper don't use attention_mask,
+        attention_mask (`torch.Tensor`)`, *optional*):
+            Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
+            but it is not used. By default the silence in the input log mel spectrogram are ignored.
+        """
+        return {
+            "input_ids": input_ids,
+            "cache_position": cache_position,
+        }
+
+    # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L512
+    def _prepare_encoder_decoder_kwargs_for_generation(
+        self, inputs_tensor: torch.Tensor, model_kwargs, *args, **kwargs
+    ) -> Dict[str, Any]:
+        if not self.is_language_detected:
+            model_kwargs["encoder_outputs"] = self.encoder(input_features=inputs_tensor)
+            self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+        else:
+            model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
+
+        return model_kwargs
+
     def forward(
         self,
-
-        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        input_ids: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        input_features: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Seq2SeqLMOutput] = None,
         **kwargs,
     ) -> Seq2SeqLMOutput:
-
-
-
-        cache_position
-
-
-
-
+        # default decoder pass
+        if input_features is None and encoder_outputs is None:
+            cross_attentions = []
+            for step in cache_position:
+                # skip step 0 if language_detection has been processed
+                if step == 0 and self.is_language_detected:
+                    cross_attentions.append(self.language_cross)
+                    self.is_language_detected = False
+                else:
+                    self.decoder_attention_mask[:, step] = 1
+                    decoder_output = self.decoder(
+                        decoder_input_ids=input_ids[:, step : step + 1].contiguous(),
+                        decoder_attention_mask=self.decoder_attention_mask,
+                        cache_position=step.to(torch.int32),
+                    )
+                    cross_attentions.append(decoder_output.cross_attentions)
+                    lm_logits = decoder_output.logits
+
+            if self.rbln_token_timestamps:
+                cross_attentions = torch.cat(cross_attentions, dim=-2)
+            else:
+                cross_attentions = None
+
+            return Seq2SeqLMOutput(logits=lm_logits, cross_attentions=cross_attentions)
+
+        # detect language pass
+        # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/whisper/generation_whisper.py#L1442
+        else:
+            if encoder_outputs is None:
+                self.encoder(input_features=input_features.contiguous())
+            self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+            self.is_language_detected = True
+            self.decoder_attention_mask[:, 0] = 1
+            decoder_output = self.decoder(
+                decoder_input_ids=decoder_input_ids.contiguous(),
+                decoder_attention_mask=self.decoder_attention_mask,
+                cache_position=torch.zeros([], dtype=torch.int32),
+            )
+            lm_logits = decoder_output.logits
+            self.language_cross = decoder_output.cross_attentions
+            return Seq2SeqLMOutput(logits=lm_logits)
```
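With the decoder pass now returning stacked cross-attentions, `RBLNWhisperGenerationMixin` can feed them into the timestamp logic during generation. A hypothetical end-to-end sketch, assuming the compiled model plugs into the standard transformers processor/`generate` flow and honors `return_token_timestamps` the way transformers' `WhisperGenerationMixin` does:

```python
# Hypothetical generation sketch; assumes RBLNWhisperGenerationMixin mirrors
# transformers' WhisperGenerationMixin handling of return_token_timestamps.
from optimum.rbln import RBLNWhisperForConditionalGeneration
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("openai/whisper-small")
model = RBLNWhisperForConditionalGeneration.from_pretrained("whisper-small-rbln")  # dir from the export sketch above

audio = [0.0] * 16000  # placeholder 1-second waveform; substitute real 16 kHz audio
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

out = model.generate(inputs.input_features, return_token_timestamps=True)
print(processor.batch_decode(out["sequences"], skip_special_tokens=True))
print(out["token_timestamps"])  # per-token timestamps derived from the cross-attentions
```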
optimum/rbln/transformers/models/whisper/whisper_architecture.py:

```diff
@@ -55,7 +55,6 @@ class _WhisperAttention(WhisperAttention):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, tgt_len, _ = hidden_states.size()
         is_cross_attention = key_value_states is not None
```
```diff
@@ -99,6 +98,7 @@ class _WhisperAttention(WhisperAttention):
         if attention_mask is not None:
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)

         attn_output = torch.bmm(attn_weights, value_states)
```
```diff
@@ -109,7 +109,9 @@ class _WhisperAttention(WhisperAttention):
         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)

-
+        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+        return attn_output, attn_weights, present_key_value


 class _WhisperSdpaAttention(WhisperSdpaAttention):
```
```diff
@@ -186,6 +188,7 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         cache_position: Optional[torch.Tensor] = None,
         attn_impl: str = "eager",
+        output_attentions: bool = False,
     ) -> torch.Tensor:
         # Self Attention Block
         residual = hidden_states
```
```diff
@@ -205,14 +208,22 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
         residual = hidden_states
         hidden_states = self.encoder_attn_layer_norm(hidden_states)
         cross_attn_past_key_value = past_key_value[2:] if past_key_value is not None else None
-
-
-
-
-
-
-
-
+        if output_attentions:
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = _WhisperAttention.forward(
+                self.encoder_attn,
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                past_key_value=cross_attn_past_key_value,
+                cache_position=cache_position,
+            )
+        else:
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
+                self.encoder_attn,
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                past_key_value=cross_attn_past_key_value,
+                cache_position=cache_position,
+            )
         hidden_states = residual + hidden_states
         present_key_value = present_key_value + cross_attn_present_key_value

```
```diff
@@ -223,7 +234,7 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
         hidden_states = self.fc2(hidden_states)
         hidden_states = residual + hidden_states

-        return hidden_states, present_key_value
+        return hidden_states, present_key_value, cross_attn_weights


 class _WhisperPositionalEmbedding(WhisperPositionalEmbedding):
```
```diff
@@ -243,6 +254,7 @@ class _WhisperDecoder(WhisperDecoder):
         past_key_values: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
         attn_impl: str = "eager",
+        output_attentions: bool = False,
         **kwargs,
     ):
         input_shape = input_ids.size()
```
```diff
@@ -266,6 +278,7 @@ class _WhisperDecoder(WhisperDecoder):
         )

         next_decoder_cache = ()
+        all_cross_attentions = () if output_attentions else None
         # iterate decoder_layer
         for idx, decoder_layer in enumerate(self.layers):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
```
```diff
@@ -277,10 +290,13 @@ class _WhisperDecoder(WhisperDecoder):
                 past_key_value=past_key_value,
                 cache_position=cache_position,
                 attn_impl=attn_impl,
+                output_attentions=output_attentions,
             )
             hidden_states = layer_outputs[0]

             next_decoder_cache += (layer_outputs[1],)
+            if output_attentions:
+                all_cross_attentions += (layer_outputs[2],)

         # layer_norm
         hidden_states = self.layer_norm(hidden_states)
```
```diff
@@ -288,17 +304,19 @@ class _WhisperDecoder(WhisperDecoder):
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
             past_key_values=next_decoder_cache,
+            cross_attentions=all_cross_attentions,
         )


 class _WhisperDecoderWrapper(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, output_attentions: bool = False):
         super().__init__()
         self.proj_out = model.proj_out
         self.config = model.config
         self.decoder = model.get_decoder()
         self.num_layers = self.config.decoder_layers
         self.attn_impl = self.config._attn_implementation
+        self.output_attentions = output_attentions

     def forward(
         self,
```
```diff
@@ -329,6 +347,7 @@ class _WhisperDecoderWrapper(torch.nn.Module):
             past_key_values=kv_cache,
             encoder_hidden_states=torch.tensor([1]),
             attn_impl=self.attn_impl,
+            output_attentions=self.output_attentions,
         )
         sequence_output = decoder_outputs[0]
         lm_logits = self.proj_out(sequence_output)
```
```diff
@@ -341,7 +360,12 @@ class _WhisperDecoderWrapper(torch.nn.Module):
             self_kv_cache.append(past_key_values[i][1])
         self_kv_cache = torch.stack(self_kv_cache, dim=0)

-
+        if self.output_attentions:
+            # deocder's cross attention is used for token_timestamps
+            cross_attention = torch.stack(decoder_outputs[2], dim=0)
+            return lm_logits, self_kv_cache, cross_attention
+        else:
+            return lm_logits, self_kv_cache


 class _WhisperEncoderWrapper(torch.nn.Module):
```
```diff
@@ -363,6 +387,7 @@ class _WhisperEncoderWrapper(torch.nn.Module):
         input_features: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
         encoder_outputs = self.encoder(input_features=input_features)
+
         last_hidden_states = encoder_outputs[0]

         encoder_batch_size = input_features.shape[0]
```
```diff
@@ -388,13 +413,15 @@ class _WhisperEncoderWrapper(torch.nn.Module):
             encoder_hidden_states=last_hidden_states,
             past_key_values=dummy_past_key_value,
             attn_impl=self.attn_impl,
+            output_attentions=False,
         )

         first_past_kv = decoder_outputs[1]

-
+        cross_kv = []
         for layer_out in first_past_kv: # for layer
-
-
+            cross_kv.append(layer_out[2])
+            cross_kv.append(layer_out[3])
+        cross_kv = torch.stack(cross_kv, dim=0)

-        return
+        return cross_kv
```
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py:

```diff
@@ -22,13 +22,13 @@
 # from Rebellions Inc.

 import logging
-from typing import TYPE_CHECKING, Any, Dict,
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 import torch
 from transformers import AutoModel, PretrainedConfig, PreTrainedModel, XLMRobertaConfig, XLMRobertaModel

 from ....modeling_base import RBLNModel
-from ....modeling_config import
+from ....modeling_config import RBLNCompileConfig, RBLNConfig


 logger = logging.getLogger(__name__)
```
```diff
@@ -53,8 +53,7 @@ class RBLNXLMRobertaModel(RBLNModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-
-        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
+        rbln_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> "PreTrainedModel":
         model: "PreTrainedModel" = super().get_pytorch_model(
```
```diff
@@ -66,8 +65,7 @@ class RBLNXLMRobertaModel(RBLNModel):
             subfolder=subfolder,
             local_files_only=local_files_only,
             trust_remote_code=trust_remote_code,
-
-            rbln_constructor_kwargs=rbln_constructor_kwargs,
+            rbln_kwargs=rbln_kwargs,
             library_name="transformers",
         )

```
```diff
@@ -78,10 +76,12 @@ class RBLNXLMRobertaModel(RBLNModel):
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
-
-        rbln_model_input_names: Optional[List[str]] = None,
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs={},
     ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
         )
```
```diff
@@ -111,12 +111,15 @@ class RBLNXLMRobertaModel(RBLNModel):
             for model_input_name in rbln_model_input_names
         ]

-
-        rbln_runtime_config.batch_size = rbln_batch_size
-
-        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)

-
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config

     def forward(
         self,
```
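The XLM-Roberta changes mirror the Whisper ones: per-argument `rbln_*` parameters are replaced by a single `rbln_kwargs` dict that is read back with `rbln_kwargs.get(...)`. A hypothetical usage sketch, assuming `RBLNXLMRobertaModel` is exported from the package root (as the `__init__.py` changes suggest) and that `rbln_*` kwargs map onto `rbln_kwargs` keys:

```python
# Hypothetical usage sketch; assumes RBLNXLMRobertaModel is importable from the
# package root and that rbln_* kwargs are gathered into rbln_kwargs as shown above.
from optimum.rbln import RBLNXLMRobertaModel

model = RBLNXLMRobertaModel.from_pretrained(
    "FacebookAI/xlm-roberta-base",
    export=True,
    rbln_max_seq_len=512,  # -> rbln_kwargs["max_seq_len"]
    rbln_batch_size=1,     # -> rbln_kwargs["batch_size"]
)
model.save_pretrained("xlm-roberta-base-rbln")
```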