optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. optimum/rbln/__init__.py +37 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
  4. optimum/rbln/diffusers/models/controlnet.py +56 -40
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  10. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  14. optimum/rbln/modeling_alias.py +3 -3
  15. optimum/rbln/modeling_base.py +471 -231
  16. optimum/rbln/modeling_config.py +152 -77
  17. optimum/rbln/modeling_seq2seq.py +166 -77
  18. optimum/rbln/transformers/__init__.py +35 -1
  19. optimum/rbln/transformers/models/__init__.py +20 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  33. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  34. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  35. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  37. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  38. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
  41. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  42. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  43. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  44. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  45. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  46. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
  47. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  48. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  49. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  50. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
  51. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  52. optimum/rbln/utils/import_utils.py +36 -1
  53. optimum/rbln/utils/logging.py +82 -0
  54. optimum/rbln/utils/runtime_utils.py +33 -0
  55. optimum/rbln/utils/timer_utils.py +19 -0
  56. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
  57. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  58. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  59. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  60. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  61. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/whisper/modeling_whisper.py

@@ -30,7 +30,6 @@ import torch
  from transformers import (
  AutoModelForSpeechSeq2Seq,
  AutoProcessor,
- GenerationMixin,
  PretrainedConfig,
  WhisperForConditionalGeneration,
  WhisperModel,
@@ -38,8 +37,9 @@ from transformers import (
  from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

  from ....modeling_base import RBLNModel
- from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+ from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
  from ....utils.runtime_utils import RBLNPytorchRuntime
+ from .generation_whisper import RBLNWhisperGenerationMixin
  from .whisper_architecture import (
  _WhisperDecoderWrapper,
  _WhisperEncoderWrapper,
@@ -60,8 +60,15 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
  mandatory_members = ["main_input_name"]

  def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
- _ = super().forward(input_features=kwargs["input_features"])
- return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
+ # backward compatibility transformers==4.40.2
+ # https://github.com/huggingface/transformers/blob/4fdf58afb72b0754da30037fc800b6044e7d9c99/src/transformers/pipelines/automatic_speech_recognition.py#L494
+ input_features = kwargs.get("input_features", None)
+ if input_features is None:
+ input_features = args[0]
+
+ _ = super().forward(input_features=input_features)
+ # dummy output for generation
+ return BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))


  class RBLNRuntimeDecoder(RBLNPytorchRuntime):
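
Context for the encoder hunk above: the compiled encoder keeps its real outputs (the cross-attention key/value cache) in device memory, so the host only hands generate() a 1x1 placeholder, and input_features is now accepted either positionally or as a keyword for compatibility with the transformers 4.40.2 ASR pipeline. A small self-contained restatement of that pattern (everything outside the diff is illustrative):

    import torch
    from transformers.modeling_outputs import BaseModelOutput

    def resolve_input_features(*args, **kwargs):
        # transformers==4.40.2's ASR pipeline passes input_features positionally;
        # newer releases pass it as a keyword, hence the args[0] fallback above.
        input_features = kwargs.get("input_features", None)
        if input_features is None:
            input_features = args[0]
        return input_features

    # generate() only needs *an* encoder_outputs object to thread through its loop;
    # the real hidden states never leave the device, so a 1x1 dummy tensor is enough.
    placeholder = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
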
@@ -69,10 +76,13 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):

  def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
  outputs = super().forward(*args, **kwargs)
- return Seq2SeqLMOutput(logits=outputs)
+ if isinstance(outputs, torch.Tensor):
+ return Seq2SeqLMOutput(logits=outputs, cross_attentions=None)
+ else:
+ return Seq2SeqLMOutput(logits=outputs[0], cross_attentions=outputs[1])


- class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
+ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
  """
  The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
  This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -88,15 +98,22 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
  main_input_name = "input_ids"

  def __post_init__(self, **kwargs):
- self.batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].batch_size
- self.enc_max_seq_len = self.rbln_config.meta["input_max_length"]
- self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
+ super().__post_init__(**kwargs)
+
+ self.batch_size = self.rbln_config.model_cfg["batch_size"]
+ self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+ self.rbln_token_timestamps = self.rbln_config.model_cfg["token_timestamps"]

  self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
  self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
- self.forced_decoder_ids = self.config.forced_decoder_ids

- # used in GenerationMixin.generate()
+ # skip encoder & first decoder when language detected
+ self.is_language_detected = False
+ self.language_cross = None
+
+ # Used in GenerationMixin.generate()
+ # transformers/models/whisper/generation_whisper.py, line 505, in generate
+ # input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
  self.model = WhisperModel(self.config)
  self.pad_token_id = self.config.pad_token_id

@@ -127,63 +144,32 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
  # TODO(jongho): implement
  raise NotImplementedError

- def prepare_inputs_for_generation(
- self,
- input_ids,
- decoder_attention_mask=None,
- input_features=None, # Must be explicit
- **kwargs,
- ):
- max_seq_len = self.dec_max_seq_len
- cur_seq_len = input_ids.shape[-1]
- input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
- decoder_attention_mask = torch.zeros(self.batch_size, max_seq_len, dtype=torch.int64)
- decoder_attention_mask[:, :cur_seq_len] = 1
- cache_position = torch.tensor(cur_seq_len - 1, dtype=torch.int32)
-
- return {
- "decoder_input_ids": input_ids,
- "decoder_attention_mask": decoder_attention_mask,
- "cache_position": cache_position,
- }
-
- @classmethod
- def update_kwargs(cls, kwargs):
- kwargs.update(
- {
- "torchscript": True,
- "return_dict": False,
- "use_cache": True,
- }
- )
- return kwargs
-
  @classmethod
  @torch.inference_mode()
  def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+ rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
  wrapped_encoder = _WhisperEncoderWrapper(model).eval()
- wrapped_decoder = _WhisperDecoderWrapper(model).eval()
+ wrapped_decoder = _WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps).eval()

- enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
- dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+ enc_rbln_compile_config = rbln_config.compile_cfgs[0]
+ dec_rbln_compile_config = rbln_config.compile_cfgs[1]

- enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
- dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
+ enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=1)
+ dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=1)

- enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs[0], check_trace=False)
+ enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
  dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)

  enc_ir = rebel.torchscript_to_ir(
  enc_scripted_model,
- input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
- name=enc_rbln_runtime_config.rbln_mod_name,
+ input_names=[v[0] for v in enc_rbln_compile_config.input_info],
+ name=enc_rbln_compile_config.mod_name,
  )
  dec_ir = rebel.torchscript_to_ir(
  dec_scripted_model,
- input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
- name=dec_rbln_runtime_config.rbln_mod_name,
+ input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+ name=dec_rbln_compile_config.mod_name,
  )
- dec_ir.batch_size = dec_rbln_runtime_config.batch_size

  # Caching encoder/decoder I/O
  connections = [
@@ -194,9 +180,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
  enc_ir,
  dec_ir,
  connections=connections,
- fusion=enc_rbln_runtime_config.fusion,
- npu=enc_rbln_runtime_config.npu,
- tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
+ fusion=enc_rbln_compile_config.fusion,
+ npu=enc_rbln_compile_config.npu,
+ tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
  )
  return compiled_model

@@ -205,42 +191,22 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
  cls,
  preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
  model_config: "PretrainedConfig",
- rbln_batch_size: Optional[int] = None,
+ rbln_kwargs: Dict[str, Any] = {},
  ) -> RBLNConfig:
- meta = {}
-
- input_max_length = 3000
- rbln_enc_num_mel_bins = getattr(model_config, "num_mel_bins", None)
- if rbln_enc_num_mel_bins is None:
- for feature_extractor in preprocessors:
- if hasattr(feature_extractor, "feature_size"):
- rbln_enc_num_mel_bins = feature_extractor.feature_size
- break
- raise ValueError("`rbln_enc_num_mel_bins` should be specified!")
-
- rbln_enc_max_seq_len = getattr(model_config, "max_source_positions", None)
- if rbln_enc_max_seq_len is None:
- raise ValueError("`rbln_enc_max_seq_len` should be specified!")
-
- rbln_dec_max_seq_len = getattr(model_config, "max_length", None)
- if rbln_dec_max_seq_len is None:
- raise ValueError("`rbln_dec_max_seq_len` should be specified!")
-
+ rbln_batch_size = rbln_kwargs.get("batch_size", None)
+ rbln_token_timestamps = rbln_kwargs.get("token_timestamps", False)
  rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
- decoder_batch_size = rbln_batch_size

- meta["rbln_dec_max_seq_len"] = rbln_dec_max_seq_len
- meta["rbln_enc_max_seq_len"] = rbln_enc_max_seq_len
- meta["num_mel_bins"] = rbln_enc_num_mel_bins
- meta["input_max_length"] = input_max_length
- meta["decoder_batch_size"] = decoder_batch_size
- meta["forced_decoder_ids"] = model_config.forced_decoder_ids
+ expected_seq_len = model_config.max_source_positions * 2
+ num_mel_bins = model_config.num_mel_bins
+ enc_max_seq_len = model_config.max_source_positions
+ rbln_dec_max_seq_len = model_config.max_length

  # model input info
- enc_input_info = [("input_features", [rbln_batch_size, rbln_enc_num_mel_bins, input_max_length], "float32")]
+ enc_input_info = [("input_features", [rbln_batch_size, num_mel_bins, expected_seq_len], "float32")]
  dec_input_info = [
- ("decoder_input_ids", [decoder_batch_size, 1], "int64"),
- ("decoder_attention_mask", [decoder_batch_size, rbln_dec_max_seq_len], "int64"),
+ ("decoder_input_ids", [rbln_batch_size, 1], "int64"),
+ ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
  ("cache_position", [], "int32"),
  ]
  dec_input_info.extend(
@@ -249,7 +215,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
  "self_key_value_states",
  [
  model_config.decoder_layers * 2,
- decoder_batch_size,
+ rbln_batch_size,
  model_config.decoder_attention_heads,
  rbln_dec_max_seq_len,
  model_config.d_model // model_config.encoder_attention_heads,
@@ -266,7 +232,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
  model_config.decoder_layers * 2,
  rbln_batch_size,
  model_config.decoder_attention_heads,
- rbln_enc_max_seq_len,
+ enc_max_seq_len,
  model_config.d_model // model_config.encoder_attention_heads,
  ],
  "float32",
@@ -274,15 +240,21 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
  ]
  )

- enc_rbln_runtime_config = RBLNRuntimeConfig(rbln_mod_name="encoder", input_info=enc_input_info)
- dec_rbln_runtime_config = RBLNRuntimeConfig(rbln_mod_name="decoder", input_info=dec_input_info)
+ enc_rbln_compile_config = RBLNCompileConfig(mod_name="encoder", input_info=enc_input_info)
+ dec_rbln_compile_config = RBLNCompileConfig(mod_name="decoder", input_info=dec_input_info)

- enc_rbln_runtime_config.batch_size = rbln_batch_size
- dec_rbln_runtime_config.batch_size = decoder_batch_size
+ rbln_config = RBLNConfig(
+ rbln_cls=cls.__name__,
+ compile_cfgs=[enc_rbln_compile_config, dec_rbln_compile_config],
+ rbln_kwargs=rbln_kwargs,
+ )

- rbln_config = RBLNConfig.from_rbln_runtime_configs(
- [enc_rbln_runtime_config, dec_rbln_runtime_config],
- _rbln_meta=meta,
+ rbln_config.model_cfg.update(
+ {
+ "batch_size": rbln_batch_size,
+ "dec_max_seq_len": rbln_dec_max_seq_len,
+ "token_timestamps": rbln_token_timestamps,
+ }
  )

  return rbln_config
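
The rewritten _get_rbln_config above no longer takes explicit rbln_* parameters; it reads everything from a single rbln_kwargs dict and stores the decisions in rbln_config.model_cfg. A plausible user-facing call, assuming from_pretrained still collects rbln_-prefixed keyword arguments into that dict (the argument names are inferred from the keys above and not verified against the 0.1.11 documentation):

    from optimum.rbln import RBLNWhisperForConditionalGeneration

    # export=True compiles the Hugging Face checkpoint for the RBLN NPU;
    # rbln_token_timestamps requests a decoder compiled with cross-attention outputs.
    model = RBLNWhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-small",
        export=True,
        rbln_batch_size=1,
        rbln_token_timestamps=True,
    )
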
@@ -297,18 +269,82 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
  compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
  ]

+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ """
+ whisper don't use attention_mask,
+ attention_mask (`torch.Tensor`)`, *optional*):
+ Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
+ but it is not used. By default the silence in the input log mel spectrogram are ignored.
+ """
+ return {
+ "input_ids": input_ids,
+ "cache_position": cache_position,
+ }
+
+ # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L512
+ def _prepare_encoder_decoder_kwargs_for_generation(
+ self, inputs_tensor: torch.Tensor, model_kwargs, *args, **kwargs
+ ) -> Dict[str, Any]:
+ if not self.is_language_detected:
+ model_kwargs["encoder_outputs"] = self.encoder(input_features=inputs_tensor)
+ self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+ else:
+ model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
+
+ return model_kwargs
+
  def forward(
  self,
- decoder_input_ids: Optional[torch.LongTensor] = None,
- decoder_attention_mask: Optional[torch.LongTensor] = None,
+ input_ids: Optional[torch.LongTensor] = None,
  cache_position: Optional[torch.Tensor] = None,
+ input_features: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Seq2SeqLMOutput] = None,
  **kwargs,
  ) -> Seq2SeqLMOutput:
- decoder_output = self.decoder(
- decoder_input_ids=decoder_input_ids,
- decoder_attention_mask=decoder_attention_mask,
- cache_position=cache_position,
- )
- lm_logits = decoder_output.logits
-
- return Seq2SeqLMOutput(logits=lm_logits)
+ # default decoder pass
+ if input_features is None and encoder_outputs is None:
+ cross_attentions = []
+ for step in cache_position:
+ # skip step 0 if language_detection has been processed
+ if step == 0 and self.is_language_detected:
+ cross_attentions.append(self.language_cross)
+ self.is_language_detected = False
+ else:
+ self.decoder_attention_mask[:, step] = 1
+ decoder_output = self.decoder(
+ decoder_input_ids=input_ids[:, step : step + 1].contiguous(),
+ decoder_attention_mask=self.decoder_attention_mask,
+ cache_position=step.to(torch.int32),
+ )
+ cross_attentions.append(decoder_output.cross_attentions)
+ lm_logits = decoder_output.logits
+
+ if self.rbln_token_timestamps:
+ cross_attentions = torch.cat(cross_attentions, dim=-2)
+ else:
+ cross_attentions = None
+
+ return Seq2SeqLMOutput(logits=lm_logits, cross_attentions=cross_attentions)
+
+ # detect language pass
+ # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/whisper/generation_whisper.py#L1442
+ else:
+ if encoder_outputs is None:
+ self.encoder(input_features=input_features.contiguous())
+ self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+ self.is_language_detected = True
+ self.decoder_attention_mask[:, 0] = 1
+ decoder_output = self.decoder(
+ decoder_input_ids=decoder_input_ids.contiguous(),
+ decoder_attention_mask=self.decoder_attention_mask,
+ cache_position=torch.zeros([], dtype=torch.int32),
+ )
+ lm_logits = decoder_output.logits
+ self.language_cross = decoder_output.cross_attentions
+ return Seq2SeqLMOutput(logits=lm_logits)
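
With the two-pass forward above, language detection and token-level timestamps both run through the standard transformers generation path. A hedged end-to-end sketch continuing the compile example earlier (the processor calls and the generate keyword are ordinary transformers Whisper usage, not taken from this diff):

    import numpy as np
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("openai/whisper-small")
    audio = np.zeros(16000 * 5, dtype=np.float32)  # stand-in for 5 s of 16 kHz audio
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

    # return_token_timestamps relies on the cross-attentions the compiled decoder
    # emits when rbln_token_timestamps=True was set at compile time.
    output = model.generate(inputs.input_features, return_token_timestamps=True)
    print(processor.batch_decode(output.sequences, skip_special_tokens=True))
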
optimum/rbln/transformers/models/whisper/whisper_architecture.py

@@ -55,7 +55,6 @@ class _WhisperAttention(WhisperAttention):
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
  attention_mask: Optional[torch.Tensor] = None,
  cache_position: Optional[torch.Tensor] = None,
- **kwargs,
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  bsz, tgt_len, _ = hidden_states.size()
  is_cross_attention = key_value_states is not None
@@ -99,6 +98,7 @@ class _WhisperAttention(WhisperAttention):
  if attention_mask is not None:
  attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
  attn_weights = nn.functional.softmax(attn_weights, dim=-1)

  attn_output = torch.bmm(attn_weights, value_states)
@@ -109,7 +109,9 @@ class _WhisperAttention(WhisperAttention):
  attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  attn_output = self.out_proj(attn_output)

- return attn_output, None, present_key_value
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+ return attn_output, attn_weights, present_key_value


  class _WhisperSdpaAttention(WhisperSdpaAttention):
@@ -186,6 +188,7 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
  cache_position: Optional[torch.Tensor] = None,
  attn_impl: str = "eager",
+ output_attentions: bool = False,
  ) -> torch.Tensor:
  # Self Attention Block
  residual = hidden_states
@@ -205,14 +208,22 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
  residual = hidden_states
  hidden_states = self.encoder_attn_layer_norm(hidden_states)
  cross_attn_past_key_value = past_key_value[2:] if past_key_value is not None else None
-
- hidden_states, _, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
- self.encoder_attn,
- hidden_states=hidden_states,
- key_value_states=encoder_hidden_states,
- past_key_value=cross_attn_past_key_value,
- cache_position=cache_position,
- )
+ if output_attentions:
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = _WhisperAttention.forward(
+ self.encoder_attn,
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ past_key_value=cross_attn_past_key_value,
+ cache_position=cache_position,
+ )
+ else:
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
+ self.encoder_attn,
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ past_key_value=cross_attn_past_key_value,
+ cache_position=cache_position,
+ )
  hidden_states = residual + hidden_states
  present_key_value = present_key_value + cross_attn_present_key_value

@@ -223,7 +234,7 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
  hidden_states = self.fc2(hidden_states)
  hidden_states = residual + hidden_states

- return hidden_states, present_key_value
+ return hidden_states, present_key_value, cross_attn_weights


  class _WhisperPositionalEmbedding(WhisperPositionalEmbedding):
@@ -243,6 +254,7 @@ class _WhisperDecoder(WhisperDecoder):
  past_key_values: Optional[torch.Tensor] = None,
  cache_position: Optional[torch.Tensor] = None,
  attn_impl: str = "eager",
+ output_attentions: bool = False,
  **kwargs,
  ):
  input_shape = input_ids.size()
@@ -266,6 +278,7 @@ class _WhisperDecoder(WhisperDecoder):
  )

  next_decoder_cache = ()
+ all_cross_attentions = () if output_attentions else None
  # iterate decoder_layer
  for idx, decoder_layer in enumerate(self.layers):
  past_key_value = past_key_values[idx] if past_key_values is not None else None
@@ -277,10 +290,13 @@ class _WhisperDecoder(WhisperDecoder):
  past_key_value=past_key_value,
  cache_position=cache_position,
  attn_impl=attn_impl,
+ output_attentions=output_attentions,
  )
  hidden_states = layer_outputs[0]

  next_decoder_cache += (layer_outputs[1],)
+ if output_attentions:
+ all_cross_attentions += (layer_outputs[2],)

  # layer_norm
  hidden_states = self.layer_norm(hidden_states)
@@ -288,17 +304,19 @@ class _WhisperDecoder(WhisperDecoder):
  return BaseModelOutputWithPastAndCrossAttentions(
  last_hidden_state=hidden_states,
  past_key_values=next_decoder_cache,
+ cross_attentions=all_cross_attentions,
  )


  class _WhisperDecoderWrapper(torch.nn.Module):
- def __init__(self, model):
+ def __init__(self, model, output_attentions: bool = False):
  super().__init__()
  self.proj_out = model.proj_out
  self.config = model.config
  self.decoder = model.get_decoder()
  self.num_layers = self.config.decoder_layers
  self.attn_impl = self.config._attn_implementation
+ self.output_attentions = output_attentions

  def forward(
  self,
@@ -329,6 +347,7 @@ class _WhisperDecoderWrapper(torch.nn.Module):
  past_key_values=kv_cache,
  encoder_hidden_states=torch.tensor([1]),
  attn_impl=self.attn_impl,
+ output_attentions=self.output_attentions,
  )
  sequence_output = decoder_outputs[0]
  lm_logits = self.proj_out(sequence_output)
@@ -341,7 +360,12 @@ class _WhisperDecoderWrapper(torch.nn.Module):
  self_kv_cache.append(past_key_values[i][1])
  self_kv_cache = torch.stack(self_kv_cache, dim=0)

- return lm_logits, self_kv_cache
+ if self.output_attentions:
+ # deocder's cross attention is used for token_timestamps
+ cross_attention = torch.stack(decoder_outputs[2], dim=0)
+ return lm_logits, self_kv_cache, cross_attention
+ else:
+ return lm_logits, self_kv_cache


  class _WhisperEncoderWrapper(torch.nn.Module):
@@ -363,6 +387,7 @@ class _WhisperEncoderWrapper(torch.nn.Module):
  input_features: Optional[torch.LongTensor] = None,
  ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
  encoder_outputs = self.encoder(input_features=input_features)
+
  last_hidden_states = encoder_outputs[0]

  encoder_batch_size = input_features.shape[0]
@@ -388,13 +413,15 @@ class _WhisperEncoderWrapper(torch.nn.Module):
  encoder_hidden_states=last_hidden_states,
  past_key_values=dummy_past_key_value,
  attn_impl=self.attn_impl,
+ output_attentions=False,
  )

  first_past_kv = decoder_outputs[1]

- encoder_kv = []
+ cross_kv = []
  for layer_out in first_past_kv: # for layer
- encoder_kv.append(torch.stack(layer_out[2:], dim=0))
- encoder_kv = torch.stack(encoder_kv, dim=0)
+ cross_kv.append(layer_out[2])
+ cross_kv.append(layer_out[3])
+ cross_kv = torch.stack(cross_kv, dim=0)

- return encoder_kv
+ return cross_kv
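
For orientation, the cross-attention key/value tensor that _WhisperEncoderWrapper now returns is stacked per decoder layer as [k, v, k, v, ...], matching the cross key/value entry in the decoder's input_info above. A quick shape check using whisper-small's published config values (12 decoder layers, 12 attention heads, d_model=768, max_source_positions=1500; batch size assumed to be 1):

    # (decoder_layers * 2, batch, heads, enc_max_seq_len, head_dim)
    decoder_layers, heads, d_model, enc_max_seq_len, batch = 12, 12, 768, 1500, 1
    cross_kv_shape = (decoder_layers * 2, batch, heads, enc_max_seq_len, d_model // heads)
    print(cross_kv_shape)  # (24, 1, 12, 1500, 64)
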
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py

@@ -22,13 +22,13 @@
  # from Rebellions Inc.

  import logging
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union

  import torch
  from transformers import AutoModel, PretrainedConfig, PreTrainedModel, XLMRobertaConfig, XLMRobertaModel

  from ....modeling_base import RBLNModel
- from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig


  logger = logging.getLogger(__name__)
@@ -53,8 +53,7 @@ class RBLNXLMRobertaModel(RBLNModel):
  subfolder: str = "",
  local_files_only: bool = False,
  trust_remote_code: bool = False,
- rbln_config_kwargs: Optional[Dict[str, Any]] = None,
- rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
+ rbln_kwargs: Optional[Dict[str, Any]] = None,
  **kwargs,
  ) -> "PreTrainedModel":
  model: "PreTrainedModel" = super().get_pytorch_model(
@@ -66,8 +65,7 @@ class RBLNXLMRobertaModel(RBLNModel):
  subfolder=subfolder,
  local_files_only=local_files_only,
  trust_remote_code=trust_remote_code,
- rbln_config_kwargs=rbln_config_kwargs,
- rbln_constructor_kwargs=rbln_constructor_kwargs,
+ rbln_kwargs=rbln_kwargs,
  library_name="transformers",
  )

@@ -78,10 +76,12 @@ class RBLNXLMRobertaModel(RBLNModel):
  cls,
  preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
  model_config: Optional["PretrainedConfig"] = None,
- rbln_max_seq_len: Optional[int] = None,
- rbln_model_input_names: Optional[List[str]] = None,
- rbln_batch_size: Optional[int] = None,
+ rbln_kwargs={},
  ) -> RBLNConfig:
+ rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+ rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+ rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
  max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
  model_config, "max_position_embeddings", None
  )
@@ -111,12 +111,15 @@ class RBLNXLMRobertaModel(RBLNModel):
  for model_input_name in rbln_model_input_names
  ]

- rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
- rbln_runtime_config.batch_size = rbln_batch_size
-
- meta = {"rbln_max_seq_len": rbln_max_seq_len}
+ rbln_compile_config = RBLNCompileConfig(input_info=input_info)

- return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+ rbln_config = RBLNConfig(
+ rbln_cls=cls.__name__,
+ compile_cfgs=[rbln_compile_config],
+ rbln_kwargs=rbln_kwargs,
+ )
+ rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+ return rbln_config

  def forward(
  self,
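
As with the Whisper entry point, RBLNXLMRobertaModel._get_rbln_config now reads its options from rbln_kwargs and records max_seq_len in model_cfg. A hedged usage sketch (the import path and the rbln_* keyword spellings are inferred from this diff, not verified against the released package):

    from optimum.rbln import RBLNXLMRobertaModel

    model = RBLNXLMRobertaModel.from_pretrained(
        "FacebookAI/xlm-roberta-base",
        export=True,  # compile the checkpoint for the RBLN NPU
        rbln_max_seq_len=512,
        rbln_model_input_names=["input_ids", "attention_mask"],
        rbln_batch_size=1,
    )
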