optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. optimum/rbln/__init__.py +37 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
  4. optimum/rbln/diffusers/models/controlnet.py +56 -40
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  10. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  14. optimum/rbln/modeling_alias.py +3 -3
  15. optimum/rbln/modeling_base.py +471 -231
  16. optimum/rbln/modeling_config.py +152 -77
  17. optimum/rbln/modeling_seq2seq.py +166 -77
  18. optimum/rbln/transformers/__init__.py +35 -1
  19. optimum/rbln/transformers/models/__init__.py +20 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  33. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  34. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  35. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  37. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  38. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
  41. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  42. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  43. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  44. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  45. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  46. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
  47. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  48. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  49. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  50. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
  51. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  52. optimum/rbln/utils/import_utils.py +36 -1
  53. optimum/rbln/utils/logging.py +82 -0
  54. optimum/rbln/utils/runtime_utils.py +33 -0
  55. optimum/rbln/utils/timer_utils.py +19 -0
  56. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
  57. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  58. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  59. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  60. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  61. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/t5/t5_architecture.py

@@ -49,6 +49,7 @@ class T5Encoder(T5Stack):
  input_ids: torch.Tensor,
  attention_mask: torch.Tensor,
  position_bias: torch.Tensor,
+ batch_ids: torch.Tensor = None,
  ) -> BaseModelOutput:
  hidden_states = self.embed_tokens(input_ids)
  extended_attention_mask = self.invert_attention_mask(attention_mask)
@@ -58,6 +59,7 @@ class T5Encoder(T5Stack):
  layer_module,
  hidden_states,
  position_bias=position_bias,
+ batch_ids=batch_ids,
  )
  hidden_states = layer_outputs[0]
  hidden_states = self.final_layer_norm(hidden_states)
@@ -75,6 +77,7 @@ class T5Decoder(T5Stack):
  position_bias: torch.Tensor,
  encoder_decoder_position_bias: torch.Tensor,
  cache_position: torch.Tensor,
+ batch_ids: torch.Tensor,
  ) -> BaseModelOutputWithPastAndCrossAttentions:
  hidden_states = self.embed_tokens(input_ids)
  extended_attention_mask = self.invert_attention_mask(attention_mask)
@@ -84,6 +87,7 @@ class T5Decoder(T5Stack):
  encoder_decoder_position_bias = encoder_decoder_position_bias + encoder_extended_attention_mask

  present_key_value_states = ()
+
  for layer_module, past_key_value in zip(self.block, past_key_values):
  layer_outputs = _T5Block.forward(
  layer_module,
@@ -93,6 +97,7 @@ class T5Decoder(T5Stack):
  encoder_decoder_position_bias=encoder_decoder_position_bias,
  past_key_value=past_key_value,
  cache_position=cache_position,
+ batch_ids=batch_ids,
  )
  hidden_states, present_key_value_state = layer_outputs[:2]
  present_key_value_states = present_key_value_states + (present_key_value_state,)
@@ -119,7 +124,9 @@ class T5EncoderWrapper(torch.nn.Module):
  self.decoder_max_length = None
  self.decoder_batch_size = 1

- def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
+ cross_key_value: torch.Tensor = None, batch_idx: torch.Tensor = None,
+ ) -> torch.Tensor:
  encoder_batch_size = input_ids.shape[0]
  decoder_batch_size = self.decoder_batch_size
  decoder_max_length = self.decoder_max_length or self.default_max_length
@@ -127,7 +134,7 @@ class T5EncoderWrapper(torch.nn.Module):

  attn_layer = self.encoder.block[0].layer[0].SelfAttention
  encoder_position_bias = T5Attention.compute_bias(attn_layer, encoder_max_length, encoder_max_length)
- encoder_outputs = T5Encoder.forward(self.encoder, input_ids, attention_mask, encoder_position_bias)
+ encoder_outputs = T5Encoder.forward(self.encoder, input_ids, attention_mask, encoder_position_bias, batch_ids=torch.tensor(0, dtype=torch.int32))

  attn_layer = self.decoder.block[0].layer[0].SelfAttention
  decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
@@ -153,7 +160,7 @@ class T5EncoderWrapper(torch.nn.Module):
  layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
  dummy_past_key_value.append(layer_pkv)

- decoder_attention_mask = torch.zeros(decoder_batch_size, decoder_max_length, dtype=torch.int64)
+ decoder_attention_mask = torch.zeros(decoder_batch_size, decoder_max_length, dtype=torch.float32)
  decoder_attention_mask[:, :1] = 1

  # Since first step of decoder has different graph to further step of it,
@@ -169,6 +176,7 @@ class T5EncoderWrapper(torch.nn.Module):
  encoder_attention_mask=attention_mask,
  past_key_values=dummy_past_key_value,
  cache_position=torch.tensor(0, dtype=torch.int32),
+ batch_ids=torch.tensor(0, dtype=torch.int32),
  )

  past_key_values = decoder_outputs.past_key_values
@@ -179,7 +187,9 @@ class T5EncoderWrapper(torch.nn.Module):
  cross_kv_cache.append(past_key_values[i][3])
  cross_kv_cache = torch.stack(cross_kv_cache, dim=0)

- return cross_kv_cache
+ cross_key_value = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx+1)
+
+ return cross_key_value


  class T5DecoderWrapper(torch.nn.Module):
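Note: with this change, T5EncoderWrapper.forward no longer returns the cross-attention KV pair directly; it scatters the freshly computed cache into a caller-provided, batch-major buffer at the slot given by batch_idx. A minimal sketch of that slice_scatter pattern, with illustrative shapes and names rather than the compiled model's actual I/O layout:

```python
import torch

# Illustrative dimensions only.
num_layers, batch, heads, enc_len, head_dim = 2, 4, 8, 16, 64

# Preallocated cache covering the whole decoding batch.
cross_key_value = torch.zeros(num_layers * 2, batch, heads, enc_len, head_dim)

# Cross-attention KV computed for a single sequence.
cross_kv_cache = torch.randn(num_layers * 2, 1, heads, enc_len, head_dim)

batch_idx = 2  # slot this sequence occupies in the decoder batch
updated = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx + 1)

assert torch.equal(updated[:, batch_idx:batch_idx + 1], cross_kv_cache)
```

Writing into a preallocated buffer keeps the encoder graph's output shape fixed regardless of which batch slot is being filled.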
@@ -201,6 +211,7 @@ class T5DecoderWrapper(torch.nn.Module):
  attention_mask: torch.Tensor,
  encoder_attention_mask: torch.Tensor,
  cache_position: torch.Tensor,
+ batch_position: torch.Tensor,
  self_kv_cache: torch.Tensor,
  cross_kv_cache: torch.Tensor,
  ) -> Tuple[torch.Tensor]:
@@ -210,6 +221,11 @@ class T5DecoderWrapper(torch.nn.Module):
  encoder_max_length = self.encoder_max_length or self.default_max_length
  decoder_max_length = self.decoder_max_length or self.default_max_length

+ if input_ids.shape[1] == 1:
+ rbln_batch_position = None
+ else:
+ rbln_batch_position = batch_position
+
  kv_cache = ()
  for i in range(0, num_layers * 2, 2):
  kv_cache = kv_cache + (
@@ -223,7 +239,12 @@ class T5DecoderWrapper(torch.nn.Module):

  attn_layer = self.model.decoder.block[0].layer[0].SelfAttention
  _decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
- decoder_position_bias = _decoder_position_bias[:, :, cache_position].unsqueeze(2)
+
+ batch_decoder_position_bias = []
+ for i in range(input_ids.shape[0]):
+ batch_position_bias = _decoder_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
+ batch_decoder_position_bias.append(batch_position_bias)
+ decoder_position_bias = torch.cat(batch_decoder_position_bias, dim=0)

  attn_layer = self.model.decoder.block[0].layer[1].EncDecAttention
  encoder_decoder_position_bias = torch.zeros(1, attn_layer.n_heads, 1, encoder_max_length)
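Note: because each sequence in the batch can now sit at a different decoding step, the relative position bias is looked up per sample from its own cache_position instead of one shared index. A small sketch of that lookup, assuming an illustrative bias of shape (1, n_heads, q_len, k_len) and a per-sample cache_position of shape (batch, 1):

```python
import torch

n_heads, max_len = 8, 32
position_bias = torch.randn(1, n_heads, max_len, max_len)   # T5-style relative bias table
cache_position = torch.tensor([[3], [7], [0]])               # one decode step per sample

rows = []
for i in range(cache_position.shape[0]):
    # pick the bias row for this sample's current step -> (1, n_heads, 1, max_len)
    rows.append(position_bias[:, :, cache_position[i][0]].unsqueeze(2))
per_sample_bias = torch.cat(rows, dim=0)                      # (batch, n_heads, 1, max_len)
```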
@@ -238,6 +259,7 @@ class T5DecoderWrapper(torch.nn.Module):
  encoder_decoder_position_bias=encoder_decoder_position_bias,
  past_key_values=kv_cache,
  cache_position=cache_position,
+ batch_ids=rbln_batch_position
  )

  past_key_values = decoder_outputs.past_key_values
@@ -255,7 +277,7 @@ class T5DecoderWrapper(torch.nn.Module):

  self_kv_cache = torch.stack(self_kv_cache, dim=0)

- return lm_logits, self_kv_cache
+ return lm_logits, self_kv_cache, batch_position


  class _T5Attention(T5Attention):
@@ -269,10 +291,10 @@ class _T5Attention(T5Attention):
  position_bias: torch.Tensor = None,
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
  cache_position: Optional[torch.Tensor] = None,  # current cache sequence length
+ batch_index: torch.Tensor = None,
  is_self_attn: Optional[bool] = None,
  ) -> Tuple[torch.Tensor]:
  batch_size = hidden_states.shape[0]
- cross_batch_size = key_value_states.shape[0] if not is_self_attn and cache_position == 0 else None

  def shape(states, batch_size):
  """projection"""
@@ -288,39 +310,72 @@ class _T5Attention(T5Attention):
  if is_self_attn:
  key_states = shape(self.k(hidden_states), batch_size)
  value_states = shape(self.v(hidden_states), batch_size)
- if past_key_value is not None:
- # decoder self attn
- cache_k = past_key_value[0].slice_scatter(
- key_states, dim=2, start=cache_position, end=cache_position + 1
- )
- cache_v = past_key_value[1].slice_scatter(
- value_states, dim=2, start=cache_position, end=cache_position + 1
- )
- past_key_value = (cache_k, cache_v)
- key_states, value_states = past_key_value
-
  else:
  # cross-attn
- if cache_position == 0:
- key_states = shape(self.k(key_value_states), cross_batch_size)
- value_states = shape(self.v(key_value_states), cross_batch_size)
+ if cache_position.dim() == 0 :
+ key_states = shape(self.k(key_value_states), key_value_states.shape[0])
+ value_states = shape(self.v(key_value_states), key_value_states.shape[0])
  past_key_value = key_states, value_states
  else:
  key_states = past_key_value[0]
  value_states = past_key_value[1]

- # compute scores
- scores = torch.matmul(query_states, key_states.transpose(3, 2))
- scores += position_bias
+ if (batch_index is None or batch_index == -1) and batch_size > 1:
+ all_key_states = []
+ all_value_states = []
+ all_attn_output = []
+
+ for b in range(batch_size):
+ batch_query_states = query_states[b].unsqueeze(0)
+ batch_key_states = key_states[b].unsqueeze(0)
+ batch_value_states = value_states[b].unsqueeze(0)
+
+ if is_self_attn and past_key_value is not None:
+ batch_key_states = past_key_value[0][b].unsqueeze(0).slice_scatter(
+ batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+ )
+ batch_value_states = past_key_value[1][b].unsqueeze(0).slice_scatter(
+ batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+ )
+
+ scores = torch.matmul(batch_query_states, batch_key_states.transpose(3, 2))
+ scores += position_bias[b]
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+ scores
+ )
+ attn_output = unshape(torch.matmul(attn_weights, batch_value_states), 1)
+ all_key_states.append(batch_key_states)
+ all_value_states.append(batch_value_states)
+ all_attn_output.append(attn_output)
+
+ key_states = torch.cat(all_key_states, dim=0)
+ value_states = torch.cat(all_value_states, dim=0)
+ attn_output = torch.cat(all_attn_output, dim=0)

- attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
- scores
- ) # (batch_size, n_heads, seq_length, key_length)
+ else:
+ if batch_index is None or batch_index == -1:
+ batch_index = 0

- attn_output = unshape(torch.matmul(attn_weights, value_states), batch_size) # (batch_size, seq_length, dim)
- attn_output = self.o(attn_output)
+ if is_self_attn and past_key_value is not None:
+ key_states = past_key_value[0].slice_scatter(
+ key_states, dim=2, start=cache_position, end=cache_position + 1
+ )
+ value_states = past_key_value[1].slice_scatter(
+ value_states, dim=2, start=cache_position, end=cache_position + 1
+ )
+ # compute scores
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
+ scores += position_bias
+
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+ scores
+ ) # (batch_size, n_heads, seq_length, key_length)
+
+ attn_output = unshape(torch.matmul(attn_weights, value_states), batch_size) # (batch_size, seq_length, dim)

- outputs = (attn_output,) + (past_key_value,)
+ attn_output = self.o(attn_output)
+ present_key_value = (key_states, value_states)
+ outputs = (attn_output,) + (present_key_value,)
  return outputs

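Note: the rewritten _T5Attention.forward now has two paths: a per-sample path (batch_index is None or -1 with batch_size > 1) in which every sequence updates its KV cache at its own cache_position, and a single-slot path that mirrors the old behavior. A minimal sketch of the per-sample path, with illustrative shapes and no attention mask:

```python
import torch
import torch.nn.functional as F

batch, heads, max_len, head_dim = 3, 8, 16, 64
key_cache = torch.zeros(batch, heads, max_len, head_dim)
value_cache = torch.zeros(batch, heads, max_len, head_dim)

query = torch.randn(batch, heads, 1, head_dim)    # one new token per sample
new_key = torch.randn(batch, heads, 1, head_dim)
new_value = torch.randn(batch, heads, 1, head_dim)
cache_position = torch.tensor([[2], [5], [0]])     # each sample is at its own decode step

outputs = []
for b in range(batch):
    start = int(cache_position[b][0])
    # write this sample's new KV into its own cache slot
    k = key_cache[b].unsqueeze(0).slice_scatter(new_key[b].unsqueeze(0), dim=-2, start=start, end=start + 1)
    v = value_cache[b].unsqueeze(0).slice_scatter(new_value[b].unsqueeze(0), dim=-2, start=start, end=start + 1)
    # attend over this sample's full cache length
    scores = torch.matmul(query[b].unsqueeze(0), k.transpose(3, 2))     # (1, heads, 1, max_len)
    attn = F.softmax(scores.float(), dim=-1).type_as(scores)
    outputs.append(torch.matmul(attn, v))                                # (1, heads, 1, head_dim)
attn_output = torch.cat(outputs, dim=0)
```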
@@ -331,6 +386,7 @@ class _T5LayerSelfAttention(T5LayerSelfAttention):
  position_bias: torch.Tensor = None,
  past_key_value: Tuple[torch.Tensor] = None,
  cache_position: Optional[torch.Tensor] = None,
+ batch_index: torch.Tensor = None,
  ):
  normed_hidden_states = self.layer_norm(hidden_states)
  attention_output = _T5Attention.forward(
@@ -339,6 +395,7 @@ class _T5LayerSelfAttention(T5LayerSelfAttention):
  position_bias=position_bias,
  past_key_value=past_key_value,
  cache_position=cache_position,
+ batch_index=batch_index,
  is_self_attn=True,
  )

@@ -356,6 +413,7 @@ class _T5LayerCrossAttention(T5LayerCrossAttention):
  position_bias: torch.Tensor = None,
  past_key_value: Tuple[torch.Tensor] = None,
  cache_position: Optional[torch.Tensor] = None,
+ batch_index: torch.Tensor = None,
  ):
  normed_hidden_states = self.layer_norm(hidden_states)
  attention_output = _T5Attention.forward(
@@ -365,6 +423,7 @@ class _T5LayerCrossAttention(T5LayerCrossAttention):
  position_bias=position_bias,
  past_key_value=past_key_value,
  cache_position=cache_position,
+ batch_index=batch_index,
  is_self_attn=False,
  )

@@ -383,6 +442,7 @@ class _T5Block(T5Block):
  encoder_decoder_position_bias=None,
  past_key_value=None,
  cache_position=None,
+ batch_ids=None,
  ):
  if past_key_value is not None:
  if not self.is_decoder:
@@ -403,13 +463,13 @@ class _T5Block(T5Block):
  cross_attn_past_key_value = past_key_value[2:]
  else:
  self_attn_past_key_value, cross_attn_past_key_value = None, None
-
  self_attention_outputs = _T5LayerSelfAttention.forward(
  self.layer[0],
  hidden_states=hidden_states,
  position_bias=position_bias,
  past_key_value=self_attn_past_key_value,
  cache_position=cache_position,
+ batch_index=batch_ids,
  )

  hidden_states, present_key_value_state = self_attention_outputs[:2]
@@ -423,6 +483,7 @@ class _T5Block(T5Block):
  position_bias=encoder_decoder_position_bias,
  past_key_value=cross_attn_past_key_value,
  cache_position=cache_position,
+ batch_index=batch_ids,
  )
  hidden_states = cross_attention_outputs[0]
  # Combine self attn and cross attn key value states
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -22,14 +22,14 @@
  # from Rebellions Inc.

  import logging
- from typing import TYPE_CHECKING, Optional, Union
+ from typing import TYPE_CHECKING, Any, Dict, Union

  import torch
  from transformers import AutoModelForMaskedLM, PretrainedConfig, Wav2Vec2ForCTC
  from transformers.modeling_outputs import CausalLMOutput

  from ....modeling_base import RBLNModel
- from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig


  logger = logging.getLogger(__name__)
@@ -78,10 +78,10 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
  cls,
  preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
  model_config: "PretrainedConfig",
- rbln_max_seq_len: Optional[int] = None,
- rbln_batch_size: Optional[int] = None,
+ rbln_kwargs: Dict[str, Any] = {},
  ) -> RBLNConfig:
- meta = {}
+ rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+ rbln_batch_size = rbln_kwargs.get("batch_size", None)

  if rbln_max_seq_len is None:
  for tokenizer in preprocessors:
@@ -91,8 +91,6 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
  if rbln_max_seq_len is None:
  raise ValueError("`rbln_max_seq_len` should be specified!")

- meta["rbln_max_seq_len"] = rbln_max_seq_len
-
  if rbln_batch_size is None:
  rbln_batch_size = 1

@@ -107,11 +105,19 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
  ),
  ]

- rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info, batch_size=rbln_batch_size)
+ rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+ rbln_config = RBLNConfig(
+ rbln_cls=cls.__name__,
+ compile_cfgs=[rbln_compile_config],
+ rbln_kwargs=rbln_kwargs,
+ )

- rbln_config = RBLNConfig.from_rbln_runtime_configs(
- [rbln_runtime_config],
- _rbln_meta=meta,
+ rbln_config.model_cfg.update(
+ {
+ "max_seq_len": rbln_max_seq_len,
+ "batch_size": rbln_batch_size,
+ }
  )

  return rbln_config
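Note: in 0.1.11 the Wav2Vec2 config hook reads its options from a single rbln_kwargs dict rather than explicit rbln_* keyword arguments. The helper below is a hypothetical, simplified sketch of that lookup-with-fallback pattern; the names are illustrative and it is not the optimum-rbln API:

```python
from typing import Any, Dict, Optional, Sequence


def resolve_runtime_settings(preprocessors: Sequence[Any], rbln_kwargs: Dict[str, Any]) -> Dict[str, int]:
    # Hypothetical helper: pull optional settings from one kwargs dict.
    max_seq_len: Optional[int] = rbln_kwargs.get("max_seq_len", None)
    batch_size: Optional[int] = rbln_kwargs.get("batch_size", None)

    if max_seq_len is None:
        # fall back to a tokenizer-derived length, as the diff does
        for tokenizer in preprocessors:
            if hasattr(tokenizer, "model_max_length"):
                max_seq_len = tokenizer.model_max_length
                break
        if max_seq_len is None:
            raise ValueError("`max_seq_len` should be specified!")

    if batch_size is None:
        batch_size = 1

    return {"max_seq_len": max_seq_len, "batch_size": batch_size}
```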
optimum/rbln/transformers/models/whisper/generation_whisper.py (new file)

@@ -0,0 +1,68 @@
+ import torch
+ from transformers import GenerationMixin
+ from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
+
+
+ class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
+ """
+ This class is based on transformers version 4.44.2.
+ It uses the same generate() method, so it's crucial to maintain the inheritance order.
+ Ensure WhisperGenerationMixin is listed before GenerationMixin.
+ """
+
+ def _postprocess_outputs(
+ self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, *args, **kwargs
+ ):
+ # remove all previously passed decoder input ids
+
+ ################################## rbln_change for 4.40.2 ###################################
+ # 4.40.2 has no shortform keyword; it has separate code paths from generation_fallback
+ is_shortform = kwargs.get("is_shortform", False)
+ start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
+
+ if isinstance(seek_outputs, torch.Tensor):
+ seek_outputs = seek_outputs[:, start_idx:]
+ return seek_outputs, seek_outputs
+
+ ############## rbln validation #############
+ if return_token_timestamps and not self.rbln_token_timestamps:
+ raise RuntimeError(
+ "To use .generate() with return_token_timestamps=True, the model must be compiled with rbln_token_timestamps=True. "
+ "You can compile the model by calling .from_pretrained() with export=True and rbln_token_timestamps=True as keyword arguments, "
+ "or you can generate with return_token_timestamps=False."
+ )
+
+ if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
+ num_frames = getattr(generation_config, "num_frames", None)
+ seek_outputs["token_timestamps"] = self._extract_token_timestamps(
+ seek_outputs, generation_config.alignment_heads, num_frames=num_frames
+ )
+ seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
+
+ seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
+
+ def split_by_batch_index(values, key, batch_idx):
+ if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
+ return [v[batch_idx].cpu() for v in values]
+ if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
+ return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
+ elif key == "past_key_values":
+ # we don't save `past_key_values` in rbln
+ return None
+
+ return values[batch_idx].cpu()
+
+ sequence_tokens = seek_outputs["sequences"]
+
+ ##################################### thkim change #############################################
+ valid_seekoutputs = []
+ for k, v in seek_outputs.items():
+ if v is not None and len(v) > 0 and v[0] is not None:
+ valid_seekoutputs.append((k, v))
+ seek_outputs = [
+ {k: split_by_batch_index(v, k, i) for k, v in valid_seekoutputs}
+ # {k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
+ for i in range(sequence_tokens.shape[0])
+ ]
+
+ return sequence_tokens, seek_outputs
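Note: the tail of _postprocess_outputs trims the forced decoder prompt and then splits every batched output field into one dict per sample. A toy sketch of that splitting, assuming a plain dict of batched tensors rather than real WhisperGenerationMixin outputs:

```python
import torch

start_idx = 2  # length of the forced decoder prompt, illustrative value
seek_outputs = {
    "sequences": torch.tensor([[50258, 50359, 11, 12], [50258, 50359, 21, 22]]),
    "scores": [torch.randn(2, 51865), torch.randn(2, 51865)],  # one logits tensor per decode step
}

# drop the prompt tokens, then split every remaining field per batch element
seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
per_sample = [
    {k: (v[i] if isinstance(v, torch.Tensor) else [step[i] for step in v]) for k, v in seek_outputs.items()}
    for i in range(seek_outputs["sequences"].shape[0])
]
print(per_sample[0]["sequences"])  # tensor([11, 12])
```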