optimum-rbln 0.1.9__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (73)
  1. optimum/rbln/__init__.py +47 -9
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
  4. optimum/rbln/diffusers/models/controlnet.py +53 -43
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
  15. optimum/rbln/modeling_alias.py +6 -11
  16. optimum/rbln/modeling_base.py +467 -261
  17. optimum/rbln/modeling_config.py +199 -73
  18. optimum/rbln/transformers/__init__.py +43 -1
  19. optimum/rbln/transformers/models/__init__.py +23 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
  33. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  34. optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
  35. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  37. optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
  38. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  41. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  42. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  47. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  48. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  51. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
  52. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  53. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  54. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
  55. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  56. optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
  57. optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
  58. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
  59. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  60. optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
  61. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  62. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
  63. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  64. optimum/rbln/utils/import_utils.py +50 -1
  65. optimum/rbln/utils/logging.py +82 -0
  66. optimum/rbln/utils/runtime_utils.py +33 -0
  67. optimum/rbln/utils/timer_utils.py +43 -0
  68. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
  69. optimum_rbln-0.1.12.dist-info/RECORD +103 -0
  70. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
  71. optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
  72. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  73. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/t5/t5_architecture.py

@@ -43,12 +43,19 @@ if TYPE_CHECKING:
     from transformers import T5ForConditionalGeneration
 
 
+class T5Wrapper:
+    def __init__(self, model):
+        self.encoder = T5EncoderWrapper(model)
+        self.decoder = T5DecoderWrapper(model)
+
+
 class T5Encoder(T5Stack):
     def forward(
         self,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         position_bias: torch.Tensor,
+        batch_ids: torch.Tensor = None,
     ) -> BaseModelOutput:
         hidden_states = self.embed_tokens(input_ids)
         extended_attention_mask = self.invert_attention_mask(attention_mask)
@@ -58,6 +65,7 @@ class T5Encoder(T5Stack):
                 layer_module,
                 hidden_states,
                 position_bias=position_bias,
+                batch_ids=batch_ids,
             )
             hidden_states = layer_outputs[0]
         hidden_states = self.final_layer_norm(hidden_states)
@@ -75,6 +83,7 @@ class T5Decoder(T5Stack):
         position_bias: torch.Tensor,
         encoder_decoder_position_bias: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_ids: torch.Tensor,
     ) -> BaseModelOutputWithPastAndCrossAttentions:
         hidden_states = self.embed_tokens(input_ids)
         extended_attention_mask = self.invert_attention_mask(attention_mask)
@@ -84,6 +93,7 @@ class T5Decoder(T5Stack):
         encoder_decoder_position_bias = encoder_decoder_position_bias + encoder_extended_attention_mask
 
         present_key_value_states = ()
+
         for layer_module, past_key_value in zip(self.block, past_key_values):
             layer_outputs = _T5Block.forward(
                 layer_module,
@@ -93,6 +103,7 @@ class T5Decoder(T5Stack):
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
                 past_key_value=past_key_value,
                 cache_position=cache_position,
+                batch_ids=batch_ids,
             )
             hidden_states, present_key_value_state = layer_outputs[:2]
             present_key_value_states = present_key_value_states + (present_key_value_state,)
@@ -117,17 +128,26 @@ class T5EncoderWrapper(torch.nn.Module):
         )
         self.encoder_max_length = None
         self.decoder_max_length = None
-        self.decoder_batch_size = 1
 
-    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        encoder_batch_size = input_ids.shape[0]
-        decoder_batch_size = self.decoder_batch_size
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cross_key_value: torch.Tensor = None,
+        batch_idx: torch.Tensor = None,
+    ) -> torch.Tensor:
         decoder_max_length = self.decoder_max_length or self.default_max_length
         encoder_max_length = self.encoder_max_length or self.default_max_length
 
         attn_layer = self.encoder.block[0].layer[0].SelfAttention
         encoder_position_bias = T5Attention.compute_bias(attn_layer, encoder_max_length, encoder_max_length)
-        encoder_outputs = T5Encoder.forward(self.encoder, input_ids, attention_mask, encoder_position_bias)
+        encoder_outputs = T5Encoder.forward(
+            self.encoder,
+            input_ids,
+            attention_mask,
+            encoder_position_bias,
+            batch_ids=torch.tensor(0, dtype=torch.int32),
+        )
 
         attn_layer = self.decoder.block[0].layer[0].SelfAttention
         decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
@@ -138,22 +158,14 @@ class T5EncoderWrapper(torch.nn.Module):
 
         dummy_past_key_value = []
         for i in range(self.config.num_layers):
-            pkv_self_attn_key = torch.zeros(
-                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
-            )
-            pkv_self_attn_value = torch.zeros(
-                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
-            )
-            pkv_cross_attn_key = torch.zeros(
-                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
-            )
-            pkv_cross_attn_value = torch.zeros(
-                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
-            )
+            pkv_self_attn_key = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
+            pkv_self_attn_value = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
+            pkv_cross_attn_key = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
+            pkv_cross_attn_value = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
             dummy_past_key_value.append(layer_pkv)
 
-        decoder_attention_mask = torch.zeros(decoder_batch_size, decoder_max_length, dtype=torch.int64)
+        decoder_attention_mask = torch.zeros(1, decoder_max_length, dtype=torch.float32)
         decoder_attention_mask[:, :1] = 1
 
         # Since first step of decoder has different graph to further step of it,
@@ -161,7 +173,7 @@ class T5EncoderWrapper(torch.nn.Module):
         # TODO(jongho): Separate first-step-decoder.
         decoder_outputs = T5Decoder.forward(
             self.decoder,
-            input_ids=torch.zeros(decoder_batch_size, 1, dtype=torch.int64),
+            input_ids=torch.zeros(1, 1, dtype=torch.int64),
             attention_mask=decoder_attention_mask,
             position_bias=decoder_position_bias,
             encoder_decoder_position_bias=encoder_decoder_position_bias,
@@ -169,6 +181,7 @@ class T5EncoderWrapper(torch.nn.Module):
             encoder_attention_mask=attention_mask,
             past_key_values=dummy_past_key_value,
             cache_position=torch.tensor(0, dtype=torch.int32),
+            batch_ids=torch.tensor(0, dtype=torch.int32),
         )
 
         past_key_values = decoder_outputs.past_key_values
@@ -179,7 +192,9 @@ class T5EncoderWrapper(torch.nn.Module):
             cross_kv_cache.append(past_key_values[i][3])
         cross_kv_cache = torch.stack(cross_kv_cache, dim=0)
 
-        return cross_kv_cache
+        cross_key_value = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx + 1)
+
+        return cross_key_value
 
 
 class T5DecoderWrapper(torch.nn.Module):
@@ -201,6 +216,7 @@ class T5DecoderWrapper(torch.nn.Module):
         attention_mask: torch.Tensor,
         encoder_attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_position: torch.Tensor,
         self_kv_cache: torch.Tensor,
         cross_kv_cache: torch.Tensor,
     ) -> Tuple[torch.Tensor]:
@@ -210,6 +226,11 @@ class T5DecoderWrapper(torch.nn.Module):
         encoder_max_length = self.encoder_max_length or self.default_max_length
         decoder_max_length = self.decoder_max_length or self.default_max_length
 
+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
+
         kv_cache = ()
         for i in range(0, num_layers * 2, 2):
            kv_cache = kv_cache + (
@@ -223,7 +244,13 @@ class T5DecoderWrapper(torch.nn.Module):
 
         attn_layer = self.model.decoder.block[0].layer[0].SelfAttention
         _decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
-        decoder_position_bias = _decoder_position_bias[:, :, cache_position].unsqueeze(2)
+
+        # position_bias needs to be computed per batch (for continuous batching)
+        batch_decoder_position_bias = []
+        for i in range(input_ids.shape[0]):
+            batch_position_bias = _decoder_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
+            batch_decoder_position_bias.append(batch_position_bias)
+        decoder_position_bias = torch.cat(batch_decoder_position_bias, dim=0)
 
         attn_layer = self.model.decoder.block[0].layer[1].EncDecAttention
         encoder_decoder_position_bias = torch.zeros(1, attn_layer.n_heads, 1, encoder_max_length)
@@ -238,6 +265,7 @@ class T5DecoderWrapper(torch.nn.Module):
             encoder_decoder_position_bias=encoder_decoder_position_bias,
             past_key_values=kv_cache,
             cache_position=cache_position,
+            batch_ids=rbln_batch_position,
         )
 
         past_key_values = decoder_outputs.past_key_values
@@ -255,7 +283,7 @@ class T5DecoderWrapper(torch.nn.Module):
 
         self_kv_cache = torch.stack(self_kv_cache, dim=0)
 
-        return lm_logits, self_kv_cache
+        return lm_logits, self_kv_cache, batch_position
 
 
 class _T5Attention(T5Attention):
@@ -269,10 +297,10 @@ class _T5Attention(T5Attention):
         position_bias: torch.Tensor = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         cache_position: Optional[torch.Tensor] = None,  # current cache sequence length
+        batch_index: torch.Tensor = None,
         is_self_attn: Optional[bool] = None,
     ) -> Tuple[torch.Tensor]:
         batch_size = hidden_states.shape[0]
-        cross_batch_size = key_value_states.shape[0] if not is_self_attn and cache_position == 0 else None
 
         def shape(states, batch_size):
             """projection"""
@@ -288,39 +316,80 @@ class _T5Attention(T5Attention):
         if is_self_attn:
             key_states = shape(self.k(hidden_states), batch_size)
             value_states = shape(self.v(hidden_states), batch_size)
-            if past_key_value is not None:
-                # decoder self attn
-                cache_k = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                cache_v = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                past_key_value = (cache_k, cache_v)
-                key_states, value_states = past_key_value
-
         else:
             # cross-attn
-            if cache_position == 0:
-                key_states = shape(self.k(key_value_states), cross_batch_size)
-                value_states = shape(self.v(key_value_states), cross_batch_size)
+            if cache_position.dim() == 0:
+                key_states = shape(self.k(key_value_states), key_value_states.shape[0])
+                value_states = shape(self.v(key_value_states), key_value_states.shape[0])
                 past_key_value = key_states, value_states
             else:
                 key_states = past_key_value[0]
                 value_states = past_key_value[1]
 
-        # compute scores
-        scores = torch.matmul(query_states, key_states.transpose(3, 2))
-        scores += position_bias
+        if (batch_index is None or batch_index == -1) and batch_size > 1:
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+
+            for b in range(batch_size):
+                batch_query_states = query_states[b].unsqueeze(0)
+                batch_key_states = key_states[b].unsqueeze(0)
+                batch_value_states = value_states[b].unsqueeze(0)
+
+                if is_self_attn and past_key_value is not None:
+                    batch_key_states = (
+                        past_key_value[0][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                    batch_value_states = (
+                        past_key_value[1][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+
+                scores = torch.matmul(batch_query_states, batch_key_states.transpose(3, 2))
+                scores += position_bias[b]
+                attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+                attn_output = unshape(torch.matmul(attn_weights, batch_value_states), 1)
+                all_key_states.append(batch_key_states)
+                all_value_states.append(batch_value_states)
+                all_attn_output.append(attn_output)
+
+            key_states = torch.cat(all_key_states, dim=0)
+            value_states = torch.cat(all_value_states, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
 
-        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
-            scores
-        )  # (batch_size, n_heads, seq_length, key_length)
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
 
-        attn_output = unshape(torch.matmul(attn_weights, value_states), batch_size)  # (batch_size, seq_length, dim)
-        attn_output = self.o(attn_output)
+            if is_self_attn and past_key_value is not None:
+                key_states = past_key_value[0].slice_scatter(
+                    key_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                value_states = past_key_value[1].slice_scatter(
+                    value_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+            # compute scores
+            scores = torch.matmul(query_states, key_states.transpose(3, 2))
+            scores += position_bias
+
+            attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+                scores
+            )  # (batch_size, n_heads, seq_length, key_length)
 
-        outputs = (attn_output,) + (past_key_value,)
+            attn_output = unshape(
+                torch.matmul(attn_weights, value_states), batch_size
+            )  # (batch_size, seq_length, dim)
+
+        attn_output = self.o(attn_output)
+        present_key_value = (key_states, value_states)
+        outputs = (attn_output,) + (present_key_value,)
         return outputs
 
 
@@ -331,6 +400,7 @@ class _T5LayerSelfAttention(T5LayerSelfAttention):
         position_bias: torch.Tensor = None,
         past_key_value: Tuple[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        batch_index: torch.Tensor = None,
     ):
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = _T5Attention.forward(
@@ -339,6 +409,7 @@ class _T5LayerSelfAttention(T5LayerSelfAttention):
            position_bias=position_bias,
            past_key_value=past_key_value,
            cache_position=cache_position,
+           batch_index=batch_index,
            is_self_attn=True,
        )
 
@@ -356,6 +427,7 @@ class _T5LayerCrossAttention(T5LayerCrossAttention):
         position_bias: torch.Tensor = None,
         past_key_value: Tuple[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        batch_index: torch.Tensor = None,
     ):
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = _T5Attention.forward(
@@ -365,6 +437,7 @@ class _T5LayerCrossAttention(T5LayerCrossAttention):
            position_bias=position_bias,
            past_key_value=past_key_value,
            cache_position=cache_position,
+           batch_index=batch_index,
            is_self_attn=False,
        )
 
@@ -383,6 +456,7 @@ class _T5Block(T5Block):
         encoder_decoder_position_bias=None,
         past_key_value=None,
         cache_position=None,
+        batch_ids=None,
     ):
         if past_key_value is not None:
             if not self.is_decoder:
@@ -403,13 +477,13 @@ class _T5Block(T5Block):
             cross_attn_past_key_value = past_key_value[2:]
         else:
             self_attn_past_key_value, cross_attn_past_key_value = None, None
-
         self_attention_outputs = _T5LayerSelfAttention.forward(
             self.layer[0],
             hidden_states=hidden_states,
             position_bias=position_bias,
             past_key_value=self_attn_past_key_value,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )
 
         hidden_states, present_key_value_state = self_attention_outputs[:2]
@@ -423,6 +497,7 @@ class _T5Block(T5Block):
             position_bias=encoder_decoder_position_bias,
             past_key_value=cross_attn_past_key_value,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )
         hidden_states = cross_attention_outputs[0]
         # Combine self attn and cross attn key value states
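The decoder self-attention above writes each step's key/value into a statically shaped cache with torch.Tensor.slice_scatter, indexed by cache_position (and, in the continuous-batching branch, per batch element). A minimal standalone sketch of that update pattern; the shapes and the cache_position value are illustrative and not taken from the diff:

import torch

# Illustrative dimensions: (batch, heads, max_len, head_dim), matching the dummy caches above.
num_heads, max_len, d_kv = 8, 16, 64
cache = torch.zeros(1, num_heads, max_len, d_kv)   # static self-attention KV cache
new_kv = torch.randn(1, num_heads, 1, d_kv)        # key/value for the current decoding step
cache_position = 5                                 # step index being written

# slice_scatter returns a copy of `cache` with `new_kv` written into the [cache_position] slot,
# keeping the cache shape fixed so the graph can be compiled ahead of time.
cache = cache.slice_scatter(new_kv, dim=2, start=cache_position, end=cache_position + 1)

assert torch.equal(cache[:, :, cache_position], new_kv[:, :, 0])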
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -22,14 +22,14 @@
 # from Rebellions Inc.
 
 import logging
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Union
 
 import torch
 from transformers import AutoModelForMaskedLM, PretrainedConfig, Wav2Vec2ForCTC
 from transformers.modeling_outputs import CausalLMOutput
 
 from ....modeling_base import RBLNModel
-from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
 
 
 logger = logging.getLogger(__name__)
@@ -65,7 +65,6 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    model_type = "rbln_model"
     main_input_name = "input_values"
     auto_model_class = AutoModelForMaskedLM
 
@@ -78,10 +77,10 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {}
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
 
         if rbln_max_seq_len is None:
             for tokenizer in preprocessors:
@@ -91,8 +90,6 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified!")
 
-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
@@ -107,11 +104,19 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
             ),
         ]
 
-        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info, batch_size=rbln_batch_size)
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [rbln_runtime_config],
-            _rbln_meta=meta,
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+            }
         )
 
         return rbln_config
optimum/rbln/transformers/models/whisper/generation_whisper.py (new file)

@@ -0,0 +1,68 @@
+import torch
+from transformers import GenerationMixin
+from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
+
+
+class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
+    """
+    This class is based on transformers version 4.44.2.
+    It uses the same generate() method, so it's crucial to maintain the inheritance order.
+    Ensure WhisperGenerationMixin is listed before GenerationMixin.
+    """
+
+    def _postprocess_outputs(
+        self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, *args, **kwargs
+    ):
+        # remove all previously passed decoder input ids
+
+        ################################## rbln_change for 4.40.2 ###################################
+        # 4.40.2 has no keyword shortform; it has separate code paths from generation_fallback
+        is_shortform = kwargs.get("is_shortform", False)
+        start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
+
+        if isinstance(seek_outputs, torch.Tensor):
+            seek_outputs = seek_outputs[:, start_idx:]
+            return seek_outputs, seek_outputs
+
+        ############## rbln validation #############
+        if return_token_timestamps and not self.rbln_token_timestamps:
+            raise RuntimeError(
+                "To use .generate() with return_token_timestamps=True, the model must be compiled with rbln_token_timestamps=True. "
+                "You can compile the model by calling .from_pretrained() with export=True and rbln_token_timestamps=True as keyword arguments, "
+                "or you can generate with return_token_timestamps=False."
+            )
+
+        if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
+            num_frames = getattr(generation_config, "num_frames", None)
+            seek_outputs["token_timestamps"] = self._extract_token_timestamps(
+                seek_outputs, generation_config.alignment_heads, num_frames=num_frames
+            )
+            seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
+
+        seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
+
+        def split_by_batch_index(values, key, batch_idx):
+            if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
+                return [v[batch_idx].cpu() for v in values]
+            if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
+                return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
+            elif key == "past_key_values":
+                # we don't save `past_key_values` in rbln
+                return None
+
+            return values[batch_idx].cpu()
+
+        sequence_tokens = seek_outputs["sequences"]
+
+        ##################################### thkim change #############################################
+        valid_seekoutputs = []
+        for k, v in seek_outputs.items():
+            if v is not None and len(v) > 0 and v[0] is not None:
+                valid_seekoutputs.append((k, v))
+        seek_outputs = [
+            {k: split_by_batch_index(v, k, i) for k, v in valid_seekoutputs}
+            # {k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
+            for i in range(sequence_tokens.shape[0])
+        ]
+
+        return sequence_tokens, seek_outputs
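As the RuntimeError message above spells out, token-level timestamps are only available when the model is compiled with rbln_token_timestamps=True. A hedged usage sketch based on that message; the RBLNWhisperForConditionalGeneration class name and the "openai/whisper-tiny" model id are assumptions for illustration, not taken from this diff:

from transformers import WhisperProcessor

# Assumed import path/class name; the keyword flags follow the error message in the mixin above.
from optimum.rbln import RBLNWhisperForConditionalGeneration

# Compile with timestamp support enabled (export=True triggers compilation for the RBLN NPU).
model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",
    export=True,
    rbln_token_timestamps=True,
)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

# ... build `input_features` from audio with the processor, then:
# outputs = model.generate(input_features, return_token_timestamps=True)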