optimum-rbln 0.1.9__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. optimum/rbln/__init__.py +47 -9
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
  4. optimum/rbln/diffusers/models/controlnet.py +53 -43
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
  15. optimum/rbln/modeling_alias.py +6 -11
  16. optimum/rbln/modeling_base.py +467 -261
  17. optimum/rbln/modeling_config.py +199 -73
  18. optimum/rbln/transformers/__init__.py +43 -1
  19. optimum/rbln/transformers/models/__init__.py +23 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
  33. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  34. optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
  35. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  37. optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
  38. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  41. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  42. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  47. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  48. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  51. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
  52. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  53. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  54. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
  55. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  56. optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
  57. optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
  58. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
  59. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  60. optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
  61. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  62. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
  63. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  64. optimum/rbln/utils/import_utils.py +50 -1
  65. optimum/rbln/utils/logging.py +82 -0
  66. optimum/rbln/utils/runtime_utils.py +33 -0
  67. optimum/rbln/utils/timer_utils.py +43 -0
  68. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
  69. optimum_rbln-0.1.12.dist-info/RECORD +103 -0
  70. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
  71. optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
  72. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  73. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/whisper/modeling_whisper.py

@@ -30,7 +30,6 @@ import torch
 from transformers import (
     AutoModelForSpeechSeq2Seq,
     AutoProcessor,
-    GenerationMixin,
     PretrainedConfig,
     WhisperForConditionalGeneration,
     WhisperModel,
@@ -38,8 +37,9 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

 from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from .generation_whisper import RBLNWhisperGenerationMixin
 from .whisper_architecture import (
     _WhisperDecoderWrapper,
     _WhisperEncoderWrapper,
@@ -59,20 +59,47 @@ if TYPE_CHECKING:
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]

-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-        _ = super().forward(input_features=kwargs["input_features"])
-        return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
+    def forward(self, input_features: torch.Tensor = None):
+        # backward compatibility transformers==4.40.2
+        # https://github.com/huggingface/transformers/blob/4fdf58afb72b0754da30037fc800b6044e7d9c99/src/transformers/pipelines/automatic_speech_recognition.py#L494
+
+        n_pad_to_batch = self.batch_size - input_features.shape[0]
+        if n_pad_to_batch > 0:
+            input_features = torch.nn.functional.pad(input_features, (0, 0, 0, 0, 0, n_pad_to_batch))
+
+        _ = super().forward(input_features=input_features)
+
+        # dummy output for generation
+        return BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))


 class RBLNRuntimeDecoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]

-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-        outputs = super().forward(*args, **kwargs)
-        return Seq2SeqLMOutput(logits=outputs)
+    def forward(
+        self,
+        decoder_input_ids: torch.Tensor = None,
+        decoder_attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+    ):
+        inputs_bsz = decoder_input_ids.shape[0]
+        padded_bsz = self.batch_size - inputs_bsz
+        if padded_bsz > 0:
+            decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, 0, 0, padded_bsz))
+
+        outputs = super().forward(
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            cache_position=cache_position,
+        )
+
+        if isinstance(outputs, torch.Tensor):
+            return Seq2SeqLMOutput(logits=outputs[:inputs_bsz], cross_attentions=None)
+        else:
+            return Seq2SeqLMOutput(logits=outputs[0][:inputs_bsz], cross_attentions=outputs[1][:, :inputs_bsz])


-class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
+class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
     """
     The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
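The new runtime wrappers pad every request up to the batch size the model was compiled with, then slice the logits back to the caller's batch size. A minimal sketch of that pad/slice pattern in plain PyTorch; the shapes and the compiled batch size below are illustrative, not taken from the package:

```python
import torch
import torch.nn.functional as F

compiled_batch_size = 2  # illustrative static batch size the graph was compiled for

# Whisper log-mel features: (batch, num_mel_bins, frames)
input_features = torch.randn(1, 80, 3000)

# F.pad takes (left, right) pairs per dimension, last dimension first, so the final
# pair (0, n_pad_to_batch) pads the leading batch dimension on the right.
n_pad_to_batch = compiled_batch_size - input_features.shape[0]
if n_pad_to_batch > 0:
    input_features = F.pad(input_features, (0, 0, 0, 0, 0, n_pad_to_batch))
print(input_features.shape)  # torch.Size([2, 80, 3000])

# After running the fixed-batch graph, outputs are sliced back to the real batch size,
# mirroring `outputs[:inputs_bsz]` in RBLNRuntimeDecoder.forward above.
fake_logits = torch.randn(compiled_batch_size, 1, 51865)  # vocab size is illustrative
real_logits = fake_logits[:1]
print(real_logits.shape)  # torch.Size([1, 1, 51865])
```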
@@ -83,20 +110,30 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
     - compiling the resulting graph using the RBLN compiler.
     """

-    model_type = "rbln_model"
     auto_model_class = AutoModelForSpeechSeq2Seq
     main_input_name = "input_ids"

     def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].batch_size
-        self.enc_max_seq_len = self.rbln_config.meta["input_max_length"]
-        self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
+        super().__post_init__(**kwargs)

-        self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
-        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
-        self.forced_decoder_ids = self.config.forced_decoder_ids
+        self.batch_size = self.rbln_config.model_cfg["batch_size"]
+        self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        self.rbln_token_timestamps = self.rbln_config.model_cfg["token_timestamps"]
+
+        self.encoder = RBLNRuntimeEncoder(
+            runtime=self.model[0], main_input_name="input_features", batch_size=self.batch_size
+        )
+        self.decoder = RBLNRuntimeDecoder(
+            runtime=self.model[1], main_input_name="input_ids", batch_size=self.batch_size
+        )

-        # used in GenerationMixin.generate()
+        # skip encoder & first decoder when language detected
+        self.is_language_detected = False
+        self.language_cross = None
+
+        # Used in GenerationMixin.generate()
+        # transformers/models/whisper/generation_whisper.py, line 505, in generate
+        # input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
         self.model = WhisperModel(self.config)
         self.pad_token_id = self.config.pad_token_id

@@ -127,63 +164,32 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
         # TODO(jongho): implement
         raise NotImplementedError

-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        decoder_attention_mask=None,
-        input_features=None,  # Must be explicit
-        **kwargs,
-    ):
-        max_seq_len = self.dec_max_seq_len
-        cur_seq_len = input_ids.shape[-1]
-        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
-        decoder_attention_mask = torch.zeros(self.batch_size, max_seq_len, dtype=torch.int64)
-        decoder_attention_mask[:, :cur_seq_len] = 1
-        cache_position = torch.tensor(cur_seq_len - 1, dtype=torch.int32)
-
-        return {
-            "decoder_input_ids": input_ids,
-            "decoder_attention_mask": decoder_attention_mask,
-            "cache_position": cache_position,
-        }
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-            }
-        )
-        return kwargs
-
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+        rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
         wrapped_encoder = _WhisperEncoderWrapper(model).eval()
-        wrapped_decoder = _WhisperDecoderWrapper(model).eval()
+        wrapped_decoder = _WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps).eval()

-        enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+        enc_rbln_compile_config = rbln_config.compile_cfgs[0]
+        dec_rbln_compile_config = rbln_config.compile_cfgs[1]

-        enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
-        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
+        enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=1)
+        dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=1)

-        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs[0], check_trace=False)
+        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
         dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)

         enc_ir = rebel.torchscript_to_ir(
             enc_scripted_model,
-            input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
-            name=enc_rbln_runtime_config.rbln_mod_name,
+            input_names=[v[0] for v in enc_rbln_compile_config.input_info],
+            name=enc_rbln_compile_config.mod_name,
         )
         dec_ir = rebel.torchscript_to_ir(
             dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-            name=dec_rbln_runtime_config.rbln_mod_name,
+            input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+            name=dec_rbln_compile_config.mod_name,
         )
-        dec_ir.batch_size = dec_rbln_runtime_config.batch_size

         # Caching encoder/decoder I/O
         connections = [
@@ -194,9 +200,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
             enc_ir,
             dec_ir,
             connections=connections,
-            fusion=enc_rbln_runtime_config.fusion,
-            npu=enc_rbln_runtime_config.npu,
-            tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
+            fusion=enc_rbln_compile_config.fusion,
+            npu=enc_rbln_compile_config.npu,
+            tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
         )
         return compiled_model

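get_compiled_model now derives example inputs from the compile configs and lowers the wrapped encoder and decoder to TorchScript before the RBLN compile step. A self-contained sketch of just the tracing stage, with a stand-in module in place of _WhisperEncoderWrapper and the rebel.* calls omitted, since those are specific to the RBLN SDK:

```python
import torch


class TinyEncoderWrapper(torch.nn.Module):
    # Stand-in for _WhisperEncoderWrapper: fixed-shape tensor in, tensor out.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(80, 16, kernel_size=3, padding=1)

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        return self.conv(input_features)


wrapped_encoder = TinyEncoderWrapper().eval()

# get_dummy_inputs(fill=1) builds example tensors from the compile config's input_info;
# here the equivalent (batch, num_mel_bins, frames) tensor is built by hand.
enc_example_inputs = (torch.ones(1, 80, 3000),)

# check_trace=False skips re-running the traced graph for verification,
# matching the call in get_compiled_model above.
enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
print(isinstance(enc_scripted_model, torch.jit.ScriptModule))  # True
```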
@@ -205,42 +211,26 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
         model_config: "PretrainedConfig",
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {}
-
-        input_max_length = 3000
-        rbln_enc_num_mel_bins = getattr(model_config, "num_mel_bins", None)
-        if rbln_enc_num_mel_bins is None:
-            for feature_extractor in preprocessors:
-                if hasattr(feature_extractor, "feature_size"):
-                    rbln_enc_num_mel_bins = feature_extractor.feature_size
-                    break
-            raise ValueError("`rbln_enc_num_mel_bins` should be specified!")
-
-        rbln_enc_max_seq_len = getattr(model_config, "max_source_positions", None)
-        if rbln_enc_max_seq_len is None:
-            raise ValueError("`rbln_enc_max_seq_len` should be specified!")
-
-        rbln_dec_max_seq_len = getattr(model_config, "max_length", None)
-        if rbln_dec_max_seq_len is None:
-            raise ValueError("`rbln_dec_max_seq_len` should be specified!")
-
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_token_timestamps = rbln_kwargs.get("token_timestamps", False)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-        decoder_batch_size = rbln_batch_size

-        meta["rbln_dec_max_seq_len"] = rbln_dec_max_seq_len
-        meta["rbln_enc_max_seq_len"] = rbln_enc_max_seq_len
-        meta["num_mel_bins"] = rbln_enc_num_mel_bins
-        meta["input_max_length"] = input_max_length
-        meta["decoder_batch_size"] = decoder_batch_size
-        meta["forced_decoder_ids"] = model_config.forced_decoder_ids
+        expected_seq_len = model_config.max_source_positions * 2
+        num_mel_bins = model_config.num_mel_bins
+        enc_max_seq_len = model_config.max_source_positions
+
+        # 'whisper-large-v3-turbo' doesn't have 'max_length', but PretrainedConfig have default value for the key 'max_length'
+        rbln_dec_max_seq_len = getattr(model_config, "max_target_positions", None)
+        if rbln_dec_max_seq_len is None:
+            rbln_dec_max_seq_len = model_config.max_length

         # model input info
-        enc_input_info = [("input_features", [rbln_batch_size, rbln_enc_num_mel_bins, input_max_length], "float32")]
+        enc_input_info = [("input_features", [rbln_batch_size, num_mel_bins, expected_seq_len], "float32")]
         dec_input_info = [
-            ("decoder_input_ids", [decoder_batch_size, 1], "int64"),
-            ("decoder_attention_mask", [decoder_batch_size, rbln_dec_max_seq_len], "int64"),
+            ("decoder_input_ids", [rbln_batch_size, 1], "int64"),
+            ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
             ("cache_position", [], "int32"),
         ]
         dec_input_info.extend(
@@ -249,7 +239,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
                     "self_key_value_states",
                     [
                         model_config.decoder_layers * 2,
-                        decoder_batch_size,
+                        rbln_batch_size,
                         model_config.decoder_attention_heads,
                         rbln_dec_max_seq_len,
                         model_config.d_model // model_config.encoder_attention_heads,
@@ -266,7 +256,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
                         model_config.decoder_layers * 2,
                         rbln_batch_size,
                         model_config.decoder_attention_heads,
-                        rbln_enc_max_seq_len,
+                        enc_max_seq_len,
                         model_config.d_model // model_config.encoder_attention_heads,
                     ],
                     "float32",
@@ -274,15 +264,21 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
             ]
         )

-        enc_rbln_runtime_config = RBLNRuntimeConfig(rbln_mod_name="encoder", input_info=enc_input_info)
-        dec_rbln_runtime_config = RBLNRuntimeConfig(rbln_mod_name="decoder", input_info=dec_input_info)
+        enc_rbln_compile_config = RBLNCompileConfig(mod_name="encoder", input_info=enc_input_info)
+        dec_rbln_compile_config = RBLNCompileConfig(mod_name="decoder", input_info=dec_input_info)

-        enc_rbln_runtime_config.batch_size = rbln_batch_size
-        dec_rbln_runtime_config.batch_size = decoder_batch_size
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[enc_rbln_compile_config, dec_rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )

-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [enc_rbln_runtime_config, dec_rbln_runtime_config],
-            _rbln_meta=meta,
+        rbln_config.model_cfg.update(
+            {
+                "batch_size": rbln_batch_size,
+                "dec_max_seq_len": rbln_dec_max_seq_len,
+                "token_timestamps": rbln_token_timestamps,
+            }
         )

         return rbln_config
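Each input_info entry above is a (name, shape, dtype) triple describing one static input of the compiled graph, and get_dummy_inputs(fill=1) turns those triples into example tensors for tracing. A rough re-implementation of that idea; the helper below is illustrative, not the package's actual RBLNCompileConfig API:

```python
from typing import List, Tuple

import torch

# For a whisper-base-style config: num_mel_bins=80, max_source_positions=1500,
# so expected_seq_len = 2 * 1500 = 3000 and enc_max_seq_len = 1500.
enc_input_info = [("input_features", [1, 80, 3000], "float32")]


def build_dummy_inputs(input_info: List[Tuple[str, list, str]], fill: int = 1):
    """Illustrative stand-in for RBLNCompileConfig.get_dummy_inputs(fill=...)."""
    dtype_map = {"float32": torch.float32, "int64": torch.int64, "int32": torch.int32}
    return tuple(
        torch.full(size=shape, fill_value=fill, dtype=dtype_map[dtype])
        for _, shape, dtype in input_info
    )


dummy = build_dummy_inputs(enc_input_info)
print(dummy[0].shape, dummy[0].dtype)  # torch.Size([1, 80, 3000]) torch.float32
```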
@@ -297,18 +293,83 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
             compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]

+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        cache_position: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,  # need for support transformers>=4.45.0
+        **kwargs,
+    ):
+        """
+        whisper don't use attention_mask,
+        attention_mask (`torch.Tensor`)`, *optional*):
+            Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
+            but it is not used. By default the silence in the input log mel spectrogram are ignored.
+        """
+        return {
+            "input_ids": input_ids,
+            "cache_position": cache_position,
+        }
+
+    # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L512
+    def _prepare_encoder_decoder_kwargs_for_generation(
+        self, inputs_tensor: torch.Tensor, model_kwargs, *args, **kwargs
+    ) -> Dict[str, Any]:
+        if not self.is_language_detected:
+            model_kwargs["encoder_outputs"] = self.encoder(input_features=inputs_tensor)
+            self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+        else:
+            model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
+
+        return model_kwargs
+
     def forward(
         self,
-        decoder_input_ids: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        input_ids: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        input_features: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Seq2SeqLMOutput] = None,
         **kwargs,
     ) -> Seq2SeqLMOutput:
-        decoder_output = self.decoder(
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            cache_position=cache_position,
-        )
-        lm_logits = decoder_output.logits
-
-        return Seq2SeqLMOutput(logits=lm_logits)
+        # default decoder pass
+        if input_features is None and encoder_outputs is None:
+            cross_attentions = []
+            for step in cache_position:
+                # skip step 0 if language_detection has been processed
+                if step == 0 and self.is_language_detected:
+                    cross_attentions.append(self.language_cross)
+                    self.is_language_detected = False
+                else:
+                    self.decoder_attention_mask[:, step] = 1
+                    decoder_output = self.decoder(
+                        decoder_input_ids=input_ids[:, step : step + 1].contiguous(),
+                        decoder_attention_mask=self.decoder_attention_mask,
+                        cache_position=step.to(torch.int32),
+                    )
+                    cross_attentions.append(decoder_output.cross_attentions)
+                    lm_logits = decoder_output.logits
+
+            if self.rbln_token_timestamps:
+                cross_attentions = torch.cat(cross_attentions, dim=-2)
+            else:
+                cross_attentions = None
+
+            return Seq2SeqLMOutput(logits=lm_logits, cross_attentions=cross_attentions)
+
+        # detect language pass
+        # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/whisper/generation_whisper.py#L1442
+        else:
+            if encoder_outputs is None:
+                self.encoder(input_features=input_features.contiguous())
+            self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+            self.is_language_detected = True
+            self.decoder_attention_mask[:, 0] = 1
+            decoder_output = self.decoder(
+                decoder_input_ids=decoder_input_ids.contiguous(),
+                decoder_attention_mask=self.decoder_attention_mask,
+                cache_position=torch.zeros([], dtype=torch.int32),
+            )
+            lm_logits = decoder_output.logits
+            self.language_cross = decoder_output.cross_attentions
+            return Seq2SeqLMOutput(logits=lm_logits)
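Together with the RBLNWhisperGenerationMixin above, these changes let the compiled model drive the standard Whisper generation loop, including language detection and token-level timestamps computed from decoder cross-attentions. A hypothetical end-to-end sketch; the `export=True` flag and the `rbln_batch_size`/`rbln_token_timestamps` keyword names are assumptions inferred from the `rbln_kwargs` handling in this diff, not confirmed API:

```python
import numpy as np
from transformers import WhisperProcessor

from optimum.rbln import RBLNWhisperForConditionalGeneration  # assumed public export

processor = WhisperProcessor.from_pretrained("openai/whisper-base")

# Assumption: rbln_-prefixed kwargs are collected into rbln_kwargs
# ("batch_size", "token_timestamps") by _get_rbln_config at compile time.
model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-base",
    export=True,                  # assumed optimum-style "compile on load" flag
    rbln_batch_size=1,
    rbln_token_timestamps=True,   # compile the decoder with cross-attention outputs
)

audio = np.zeros(16000, dtype=np.float32)  # one second of silence as stand-in audio
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
generated = model.generate(inputs.input_features, return_token_timestamps=True)
```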
optimum/rbln/transformers/models/whisper/whisper_architecture.py

@@ -55,7 +55,6 @@ class _WhisperAttention(WhisperAttention):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, tgt_len, _ = hidden_states.size()
         is_cross_attention = key_value_states is not None
@@ -99,6 +98,7 @@ class _WhisperAttention(WhisperAttention):
         if attention_mask is not None:
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)

         attn_output = torch.bmm(attn_weights, value_states)
@@ -109,7 +109,9 @@ class _WhisperAttention(WhisperAttention):
         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)

-        return attn_output, None, present_key_value
+        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+        return attn_output, attn_weights, present_key_value


 class _WhisperSdpaAttention(WhisperSdpaAttention):
@@ -186,6 +188,7 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         cache_position: Optional[torch.Tensor] = None,
         attn_impl: str = "eager",
+        output_attentions: bool = False,
     ) -> torch.Tensor:
         # Self Attention Block
         residual = hidden_states
@@ -205,14 +208,22 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
         residual = hidden_states
         hidden_states = self.encoder_attn_layer_norm(hidden_states)
         cross_attn_past_key_value = past_key_value[2:] if past_key_value is not None else None
-
-        hidden_states, _, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
-            self.encoder_attn,
-            hidden_states=hidden_states,
-            key_value_states=encoder_hidden_states,
-            past_key_value=cross_attn_past_key_value,
-            cache_position=cache_position,
-        )
+        if output_attentions:
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = _WhisperAttention.forward(
+                self.encoder_attn,
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                past_key_value=cross_attn_past_key_value,
+                cache_position=cache_position,
+            )
+        else:
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
+                self.encoder_attn,
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                past_key_value=cross_attn_past_key_value,
+                cache_position=cache_position,
+            )
         hidden_states = residual + hidden_states
         present_key_value = present_key_value + cross_attn_present_key_value

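When token timestamps are requested, the decoder layer above calls `_WhisperAttention.forward` as an unbound function on the existing `encoder_attn` module, forcing the eager path (which returns attention weights) without swapping the module itself. A small sketch of that calling pattern with stand-in classes; all names below are illustrative:

```python
import torch


class BaseAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        # default path: no attention weights returned
        return self.proj(x), None


class EagerAttention(BaseAttention):
    def forward(self, x):
        out = self.proj(x)
        weights = torch.ones(x.shape[0], x.shape[1], x.shape[1])  # stand-in weights
        return out, weights


module = BaseAttention()  # the module actually registered on the model
x = torch.randn(2, 4, 8)

out, weights = module(x)                          # normal bound call -> weights is None
out, weights = EagerAttention.forward(module, x)  # unbound call reuses the module's parameters
print(weights.shape)  # torch.Size([2, 4, 4])
```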
@@ -223,7 +234,7 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
         hidden_states = self.fc2(hidden_states)
         hidden_states = residual + hidden_states

-        return hidden_states, present_key_value
+        return hidden_states, present_key_value, cross_attn_weights


 class _WhisperPositionalEmbedding(WhisperPositionalEmbedding):
@@ -243,6 +254,7 @@ class _WhisperDecoder(WhisperDecoder):
         past_key_values: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
         attn_impl: str = "eager",
+        output_attentions: bool = False,
         **kwargs,
     ):
         input_shape = input_ids.size()
@@ -266,6 +278,7 @@ class _WhisperDecoder(WhisperDecoder):
         )

         next_decoder_cache = ()
+        all_cross_attentions = () if output_attentions else None
         # iterate decoder_layer
         for idx, decoder_layer in enumerate(self.layers):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
@@ -277,10 +290,13 @@ class _WhisperDecoder(WhisperDecoder):
                 past_key_value=past_key_value,
                 cache_position=cache_position,
                 attn_impl=attn_impl,
+                output_attentions=output_attentions,
             )
             hidden_states = layer_outputs[0]

             next_decoder_cache += (layer_outputs[1],)
+            if output_attentions:
+                all_cross_attentions += (layer_outputs[2],)

         # layer_norm
         hidden_states = self.layer_norm(hidden_states)
@@ -288,17 +304,19 @@ class _WhisperDecoder(WhisperDecoder):
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
             past_key_values=next_decoder_cache,
+            cross_attentions=all_cross_attentions,
         )


 class _WhisperDecoderWrapper(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, output_attentions: bool = False):
         super().__init__()
         self.proj_out = model.proj_out
         self.config = model.config
         self.decoder = model.get_decoder()
         self.num_layers = self.config.decoder_layers
         self.attn_impl = self.config._attn_implementation
+        self.output_attentions = output_attentions

     def forward(
         self,
@@ -329,6 +347,7 @@ class _WhisperDecoderWrapper(torch.nn.Module):
             past_key_values=kv_cache,
             encoder_hidden_states=torch.tensor([1]),
             attn_impl=self.attn_impl,
+            output_attentions=self.output_attentions,
         )
         sequence_output = decoder_outputs[0]
         lm_logits = self.proj_out(sequence_output)
@@ -341,7 +360,12 @@ class _WhisperDecoderWrapper(torch.nn.Module):
             self_kv_cache.append(past_key_values[i][1])
         self_kv_cache = torch.stack(self_kv_cache, dim=0)

-        return lm_logits, self_kv_cache
+        if self.output_attentions:
+            # deocder's cross attention is used for token_timestamps
+            cross_attention = torch.stack(decoder_outputs[2], dim=0)
+            return lm_logits, self_kv_cache, cross_attention
+        else:
+            return lm_logits, self_kv_cache


 class _WhisperEncoderWrapper(torch.nn.Module):
@@ -363,6 +387,7 @@ class _WhisperEncoderWrapper(torch.nn.Module):
         input_features: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
         encoder_outputs = self.encoder(input_features=input_features)
+
         last_hidden_states = encoder_outputs[0]

         encoder_batch_size = input_features.shape[0]
@@ -388,13 +413,15 @@ class _WhisperEncoderWrapper(torch.nn.Module):
             encoder_hidden_states=last_hidden_states,
             past_key_values=dummy_past_key_value,
             attn_impl=self.attn_impl,
+            output_attentions=False,
         )

         first_past_kv = decoder_outputs[1]

-        encoder_kv = []
+        cross_kv = []
         for layer_out in first_past_kv:  # for layer
-            encoder_kv.append(torch.stack(layer_out[2:], dim=0))
-        encoder_kv = torch.stack(encoder_kv, dim=0)
+            cross_kv.append(layer_out[2])
+            cross_kv.append(layer_out[3])
+        cross_kv = torch.stack(cross_kv, dim=0)

-        return encoder_kv
+        return cross_kv
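The encoder wrapper now keeps each decoder layer's cross-attention key and value (entries 2 and 3 of the layer's past_key_value tuple) and stacks them into a single tensor whose leading dimension is decoder_layers * 2, matching the cross_key_value_states entry in the decoder's input_info. A quick shape check of that stacking; the sizes are illustrative:

```python
import torch

decoder_layers, batch, heads, enc_len, head_dim = 4, 1, 8, 1500, 64

# one (self_k, self_v, cross_k, cross_v) tuple per decoder layer
first_past_kv = [
    tuple(torch.zeros(batch, heads, enc_len, head_dim) for _ in range(4))
    for _ in range(decoder_layers)
]

cross_kv = []
for layer_out in first_past_kv:
    cross_kv.append(layer_out[2])  # cross-attention keys
    cross_kv.append(layer_out[3])  # cross-attention values
cross_kv = torch.stack(cross_kv, dim=0)

# (decoder_layers * 2, batch, heads, enc_len, head_dim)
print(cross_kv.shape)  # torch.Size([8, 1, 8, 1500, 64])
```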
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py

@@ -22,13 +22,13 @@
 # from Rebellions Inc.

 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 import torch
-from transformers import AutoModel, PretrainedConfig, PreTrainedModel, XLMRobertaConfig, XLMRobertaModel
+from transformers import PretrainedConfig, PreTrainedModel, XLMRobertaConfig, XLMRobertaModel

 from ....modeling_base import RBLNModel
-from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ....modeling_config import RBLNCompileConfig, RBLNConfig


 logger = logging.getLogger(__name__)
@@ -38,7 +38,6 @@ if TYPE_CHECKING:


 class RBLNXLMRobertaModel(RBLNModel):
-    auto_model_class = AutoModel  # feature extraction
     original_model_class = XLMRobertaModel
     original_config_class = XLMRobertaConfig

@@ -53,8 +52,7 @@ class RBLNXLMRobertaModel(RBLNModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
-        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
+        rbln_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> "PreTrainedModel":
         model: "PreTrainedModel" = super().get_pytorch_model(
@@ -66,8 +64,7 @@ class RBLNXLMRobertaModel(RBLNModel):
             subfolder=subfolder,
             local_files_only=local_files_only,
             trust_remote_code=trust_remote_code,
-            rbln_config_kwargs=rbln_config_kwargs,
-            rbln_constructor_kwargs=rbln_constructor_kwargs,
+            rbln_kwargs=rbln_kwargs,
             library_name="transformers",
         )

@@ -78,10 +75,12 @@ class RBLNXLMRobertaModel(RBLNModel):
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_model_input_names: Optional[List[str]] = None,
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs={},
     ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
         )
@@ -111,12 +110,15 @@ class RBLNXLMRobertaModel(RBLNModel):
             for model_input_name in rbln_model_input_names
         ]

-        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-        rbln_runtime_config.batch_size = rbln_batch_size
-
-        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)

-        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config

     def forward(
         self,
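The XLM-Roberta model follows the same migration pattern as the Whisper changes above: the per-argument `rbln_max_seq_len`/`rbln_model_input_names`/`rbln_batch_size` parameters collapse into a single `rbln_kwargs` dict that `_get_rbln_config` reads with defaults and records on `rbln_config.model_cfg`. A rough sketch of that resolution step in isolation, with a plain dict standing in for the RBLNConfig object:

```python
from typing import Any, Dict, Optional


def resolve_model_cfg(rbln_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # Illustrative stand-in for the rbln_kwargs handling in _get_rbln_config;
    # not the package's actual RBLNConfig/model_cfg implementation.
    rbln_kwargs = rbln_kwargs or {}
    max_seq_len = rbln_kwargs.get("max_seq_len", None)
    model_input_names = rbln_kwargs.get("model_input_names", None) or ["input_ids", "attention_mask"]
    batch_size = rbln_kwargs.get("batch_size", None)
    batch_size = 1 if batch_size is None else batch_size
    return {"max_seq_len": max_seq_len, "model_input_names": model_input_names, "batch_size": batch_size}


print(resolve_model_cfg({"max_seq_len": 512, "batch_size": 4}))
# {'max_seq_len': 512, 'model_input_names': ['input_ids', 'attention_mask'], 'batch_size': 4}
```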