optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- optimum/rbln/__init__.py +37 -2
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
- optimum/rbln/diffusers/models/controlnet.py +56 -40
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
- optimum/rbln/modeling_alias.py +3 -3
- optimum/rbln/modeling_base.py +471 -231
- optimum/rbln/modeling_config.py +152 -77
- optimum/rbln/modeling_seq2seq.py +166 -77
- optimum/rbln/transformers/__init__.py +35 -1
- optimum/rbln/transformers/models/__init__.py +20 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
- optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +36 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +19 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
- optimum_rbln-0.1.11.dist-info/RECORD +93 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/t5/t5_architecture.py

@@ -49,6 +49,7 @@ class T5Encoder(T5Stack):
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         position_bias: torch.Tensor,
+        batch_ids: torch.Tensor = None,
     ) -> BaseModelOutput:
         hidden_states = self.embed_tokens(input_ids)
         extended_attention_mask = self.invert_attention_mask(attention_mask)
@@ -58,6 +59,7 @@ class T5Encoder(T5Stack):
                 layer_module,
                 hidden_states,
                 position_bias=position_bias,
+                batch_ids=batch_ids,
             )
             hidden_states = layer_outputs[0]
         hidden_states = self.final_layer_norm(hidden_states)
@@ -75,6 +77,7 @@ class T5Decoder(T5Stack):
         position_bias: torch.Tensor,
         encoder_decoder_position_bias: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_ids: torch.Tensor,
     ) -> BaseModelOutputWithPastAndCrossAttentions:
         hidden_states = self.embed_tokens(input_ids)
         extended_attention_mask = self.invert_attention_mask(attention_mask)
@@ -84,6 +87,7 @@ class T5Decoder(T5Stack):
         encoder_decoder_position_bias = encoder_decoder_position_bias + encoder_extended_attention_mask

         present_key_value_states = ()
+
         for layer_module, past_key_value in zip(self.block, past_key_values):
             layer_outputs = _T5Block.forward(
                 layer_module,
@@ -93,6 +97,7 @@ class T5Decoder(T5Stack):
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
                 past_key_value=past_key_value,
                 cache_position=cache_position,
+                batch_ids=batch_ids,
             )
             hidden_states, present_key_value_state = layer_outputs[:2]
             present_key_value_states = present_key_value_states + (present_key_value_state,)
@@ -119,7 +124,9 @@ class T5EncoderWrapper(torch.nn.Module):
         self.decoder_max_length = None
         self.decoder_batch_size = 1

-    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
+                cross_key_value: torch.Tensor = None, batch_idx: torch.Tensor = None,
+                ) -> torch.Tensor:
         encoder_batch_size = input_ids.shape[0]
         decoder_batch_size = self.decoder_batch_size
         decoder_max_length = self.decoder_max_length or self.default_max_length
@@ -127,7 +134,7 @@ class T5EncoderWrapper(torch.nn.Module):

         attn_layer = self.encoder.block[0].layer[0].SelfAttention
         encoder_position_bias = T5Attention.compute_bias(attn_layer, encoder_max_length, encoder_max_length)
-        encoder_outputs = T5Encoder.forward(self.encoder, input_ids, attention_mask, encoder_position_bias)
+        encoder_outputs = T5Encoder.forward(self.encoder, input_ids, attention_mask, encoder_position_bias, batch_ids=torch.tensor(0, dtype=torch.int32))

         attn_layer = self.decoder.block[0].layer[0].SelfAttention
         decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
@@ -153,7 +160,7 @@ class T5EncoderWrapper(torch.nn.Module):
             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
             dummy_past_key_value.append(layer_pkv)

-        decoder_attention_mask = torch.zeros(decoder_batch_size, decoder_max_length, dtype=torch.
+        decoder_attention_mask = torch.zeros(decoder_batch_size, decoder_max_length, dtype=torch.float32)
         decoder_attention_mask[:, :1] = 1

         # Since first step of decoder has different graph to further step of it,
@@ -169,6 +176,7 @@ class T5EncoderWrapper(torch.nn.Module):
             encoder_attention_mask=attention_mask,
             past_key_values=dummy_past_key_value,
             cache_position=torch.tensor(0, dtype=torch.int32),
+            batch_ids=torch.tensor(0, dtype=torch.int32),
         )

         past_key_values = decoder_outputs.past_key_values
@@ -179,7 +187,9 @@ class T5EncoderWrapper(torch.nn.Module):
             cross_kv_cache.append(past_key_values[i][3])
         cross_kv_cache = torch.stack(cross_kv_cache, dim=0)

-
+        cross_key_value = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx+1)
+
+        return cross_key_value


 class T5DecoderWrapper(torch.nn.Module):
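With this change, `T5EncoderWrapper.forward` takes a pre-allocated `cross_key_value` buffer and a `batch_idx`, and writes the cross-attention K/V computed for that single request into the matching batch slot instead of returning a freshly stacked cache. A minimal sketch of that per-slot `slice_scatter` write; the tensor names and shapes below are illustrative only, not the library's actual cache layout:

import torch

# Hypothetical layout: [num_layers * 2, batch, n_heads, enc_len, head_dim]
num_layers, batch, n_heads, enc_len, head_dim = 2, 4, 8, 16, 64
cross_key_value = torch.zeros(num_layers * 2, batch, n_heads, enc_len, head_dim)

# Cross-attention K/V freshly computed for one request (batch slice of size 1)
cross_kv_cache = torch.randn(num_layers * 2, 1, n_heads, enc_len, head_dim)

batch_idx = 2  # slot this request occupies
# slice_scatter returns a copy with cross_kv_cache written into rows [batch_idx, batch_idx + 1)
cross_key_value = cross_key_value.slice_scatter(
    cross_kv_cache, dim=1, start=batch_idx, end=batch_idx + 1
)
assert torch.equal(cross_key_value[:, batch_idx : batch_idx + 1], cross_kv_cache)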
optimum/rbln/transformers/models/t5/t5_architecture.py (continued)

@@ -201,6 +211,7 @@ class T5DecoderWrapper(torch.nn.Module):
         attention_mask: torch.Tensor,
         encoder_attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_position: torch.Tensor,
         self_kv_cache: torch.Tensor,
         cross_kv_cache: torch.Tensor,
     ) -> Tuple[torch.Tensor]:
@@ -210,6 +221,11 @@ class T5DecoderWrapper(torch.nn.Module):
         encoder_max_length = self.encoder_max_length or self.default_max_length
         decoder_max_length = self.decoder_max_length or self.default_max_length

+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
+
         kv_cache = ()
         for i in range(0, num_layers * 2, 2):
             kv_cache = kv_cache + (
@@ -223,7 +239,12 @@ class T5DecoderWrapper(torch.nn.Module):

         attn_layer = self.model.decoder.block[0].layer[0].SelfAttention
         _decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
-
+
+        batch_decoder_position_bias = []
+        for i in range(input_ids.shape[0]):
+            batch_position_bias = _decoder_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
+            batch_decoder_position_bias.append(batch_position_bias)
+        decoder_position_bias = torch.cat(batch_decoder_position_bias, dim=0)

         attn_layer = self.model.decoder.block[0].layer[1].EncDecAttention
         encoder_decoder_position_bias = torch.zeros(1, attn_layer.n_heads, 1, encoder_max_length)
@@ -238,6 +259,7 @@ class T5DecoderWrapper(torch.nn.Module):
             encoder_decoder_position_bias=encoder_decoder_position_bias,
             past_key_values=kv_cache,
             cache_position=cache_position,
+            batch_ids=rbln_batch_position
         )

         past_key_values = decoder_outputs.past_key_values
@@ -255,7 +277,7 @@ class T5DecoderWrapper(torch.nn.Module):

         self_kv_cache = torch.stack(self_kv_cache, dim=0)

-        return lm_logits, self_kv_cache
+        return lm_logits, self_kv_cache, batch_position


 class _T5Attention(T5Attention):
@@ -269,10 +291,10 @@ class _T5Attention(T5Attention):
         position_bias: torch.Tensor = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         cache_position: Optional[torch.Tensor] = None,  # current cache sequence length
+        batch_index: torch.Tensor = None,
         is_self_attn: Optional[bool] = None,
     ) -> Tuple[torch.Tensor]:
         batch_size = hidden_states.shape[0]
-        cross_batch_size = key_value_states.shape[0] if not is_self_attn and cache_position == 0 else None

         def shape(states, batch_size):
             """projection"""
@@ -288,39 +310,72 @@ class _T5Attention(T5Attention):
         if is_self_attn:
             key_states = shape(self.k(hidden_states), batch_size)
             value_states = shape(self.v(hidden_states), batch_size)
-            if past_key_value is not None:
-                # decoder self attn
-                cache_k = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                cache_v = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                past_key_value = (cache_k, cache_v)
-                key_states, value_states = past_key_value
-
         else:
             # cross-attn
-            if cache_position == 0:
-                key_states = shape(self.k(key_value_states),
-                value_states = shape(self.v(key_value_states),
+            if cache_position.dim() == 0 :
+                key_states = shape(self.k(key_value_states), key_value_states.shape[0])
+                value_states = shape(self.v(key_value_states), key_value_states.shape[0])
                 past_key_value = key_states, value_states
             else:
                 key_states = past_key_value[0]
                 value_states = past_key_value[1]

-
-
-
+        if (batch_index is None or batch_index == -1) and batch_size > 1:
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+
+            for b in range(batch_size):
+                batch_query_states = query_states[b].unsqueeze(0)
+                batch_key_states = key_states[b].unsqueeze(0)
+                batch_value_states = value_states[b].unsqueeze(0)
+
+                if is_self_attn and past_key_value is not None:
+                    batch_key_states = past_key_value[0][b].unsqueeze(0).slice_scatter(
+                        batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                    )
+                    batch_value_states = past_key_value[1][b].unsqueeze(0).slice_scatter(
+                        batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                    )
+
+                scores = torch.matmul(batch_query_states, batch_key_states.transpose(3, 2))
+                scores += position_bias[b]
+                attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+                    scores
+                )
+                attn_output = unshape(torch.matmul(attn_weights, batch_value_states), 1)
+                all_key_states.append(batch_key_states)
+                all_value_states.append(batch_value_states)
+                all_attn_output.append(attn_output)
+
+            key_states = torch.cat(all_key_states, dim=0)
+            value_states = torch.cat(all_value_states, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)

-
-
-
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0

-
-
+            if is_self_attn and past_key_value is not None:
+                key_states = past_key_value[0].slice_scatter(
+                    key_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                value_states = past_key_value[1].slice_scatter(
+                    value_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+            # compute scores
+            scores = torch.matmul(query_states, key_states.transpose(3, 2))
+            scores += position_bias
+
+            attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+                scores
+            )  # (batch_size, n_heads, seq_length, key_length)
+
+            attn_output = unshape(torch.matmul(attn_weights, value_states), batch_size)  # (batch_size, seq_length, dim)

-
+        attn_output = self.o(attn_output)
+        present_key_value = (key_states, value_states)
+        outputs = (attn_output,) + (present_key_value,)
         return outputs


@@ -331,6 +386,7 @@ class _T5LayerSelfAttention(T5LayerSelfAttention):
         position_bias: torch.Tensor = None,
         past_key_value: Tuple[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        batch_index: torch.Tensor = None,
     ):
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = _T5Attention.forward(
@@ -339,6 +395,7 @@ class _T5LayerSelfAttention(T5LayerSelfAttention):
             position_bias=position_bias,
             past_key_value=past_key_value,
             cache_position=cache_position,
+            batch_index=batch_index,
             is_self_attn=True,
         )

@@ -356,6 +413,7 @@ class _T5LayerCrossAttention(T5LayerCrossAttention):
         position_bias: torch.Tensor = None,
         past_key_value: Tuple[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        batch_index: torch.Tensor = None,
     ):
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = _T5Attention.forward(
@@ -365,6 +423,7 @@ class _T5LayerCrossAttention(T5LayerCrossAttention):
             position_bias=position_bias,
             past_key_value=past_key_value,
             cache_position=cache_position,
+            batch_index=batch_index,
             is_self_attn=False,
         )

@@ -383,6 +442,7 @@ class _T5Block(T5Block):
         encoder_decoder_position_bias=None,
         past_key_value=None,
         cache_position=None,
+        batch_ids=None,
     ):
         if past_key_value is not None:
             if not self.is_decoder:
@@ -403,13 +463,13 @@ class _T5Block(T5Block):
             cross_attn_past_key_value = past_key_value[2:]
         else:
             self_attn_past_key_value, cross_attn_past_key_value = None, None
-
         self_attention_outputs = _T5LayerSelfAttention.forward(
             self.layer[0],
             hidden_states=hidden_states,
             position_bias=position_bias,
             past_key_value=self_attn_past_key_value,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )

         hidden_states, present_key_value_state = self_attention_outputs[:2]
@@ -423,6 +483,7 @@ class _T5Block(T5Block):
             position_bias=encoder_decoder_position_bias,
             past_key_value=cross_attn_past_key_value,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )
         hidden_states = cross_attention_outputs[0]
         # Combine self attn and cross attn key value states
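The reworked `_T5Attention.forward` supports batches whose sequences sit at different decoding steps: when no single `batch_index` is given, it loops over the batch so that each sequence writes its current-step key/value into its own cache row at its own `cache_position[b]` (a `slice_scatter` along the sequence axis) before attending over the full fixed-length cache. A standalone sketch of that per-row cache update; the shapes are illustrative, not the compiled model's actual layout:

import torch

batch, n_heads, max_len, head_dim = 2, 4, 8, 16
key_cache = torch.zeros(batch, n_heads, max_len, head_dim)  # static decoder K cache
new_key = torch.randn(batch, n_heads, 1, head_dim)          # this step's key, one token per sequence
cache_position = torch.tensor([[3], [5]])                   # each sequence is at a different step

updated_rows = []
for b in range(batch):
    pos = int(cache_position[b][0])
    # write the new token's key into row b at its own position, without touching other rows
    row = key_cache[b].unsqueeze(0).slice_scatter(
        new_key[b].unsqueeze(0), dim=-2, start=pos, end=pos + 1
    )
    updated_rows.append(row)
key_cache = torch.cat(updated_rows, dim=0)

assert torch.equal(key_cache[0, :, 3:4], new_key[0])
assert torch.equal(key_cache[1, :, 5:6], new_key[1])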
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -22,14 +22,14 @@
 # from Rebellions Inc.

 import logging
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Any, Dict, Union

 import torch
 from transformers import AutoModelForMaskedLM, PretrainedConfig, Wav2Vec2ForCTC
 from transformers.modeling_outputs import CausalLMOutput

 from ....modeling_base import RBLNModel
-from ....modeling_config import
+from ....modeling_config import RBLNCompileConfig, RBLNConfig


 logger = logging.getLogger(__name__)
@@ -78,10 +78,10 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)

         if rbln_max_seq_len is None:
             for tokenizer in preprocessors:
@@ -91,8 +91,6 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified!")

-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-
         if rbln_batch_size is None:
             rbln_batch_size = 1

@@ -107,11 +105,19 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
             ),
         ]

-
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )

-        rbln_config
-
-
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+            }
         )

         return rbln_config
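The config builder now reads everything from a single `rbln_kwargs` dict and records the resolved values in `rbln_config.model_cfg` under the keys `max_seq_len` and `batch_size`. A hedged usage sketch of what compiling with these options could look like; the checkpoint id is a placeholder, the import path and `save_pretrained` call are assumptions, and only the `rbln_max_seq_len`/`rbln_batch_size` keyword names are taken from the diff:

from optimum.rbln import RBLNWav2Vec2ForCTC  # assumed import path

model = RBLNWav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",  # placeholder checkpoint
    export=True,                    # compile from the PyTorch weights
    rbln_max_seq_len=160000,        # static input length the graph is traced for (example value)
    rbln_batch_size=1,              # defaults to 1 in _get_rbln_config when omitted
)
model.save_pretrained("wav2vec2-rbln")  # assumed to persist the compiled artifacts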
optimum/rbln/transformers/models/whisper/generation_whisper.py (new file)

@@ -0,0 +1,68 @@
+import torch
+from transformers import GenerationMixin
+from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
+
+
+class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
+    """
+    This class is based on transformers version 4.44.2.
+    It uses the same generate() method, so it's crucial to maintain the inheritance order.
+    Ensure WhisperGenerationMixin is listed before GenerationMixin.
+    """
+
+    def _postprocess_outputs(
+        self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, *args, **kwargs
+    ):
+        # remove all previously passed decoder input ids
+
+        ################################## rbln_change for 4.40.2 ###################################
+        # 4.40.2 has no keyword shortform, it has separate codes from generation_fallback
+        is_shortform = kwargs.get("is_shortform", False)
+        start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
+
+        if isinstance(seek_outputs, torch.Tensor):
+            seek_outputs = seek_outputs[:, start_idx:]
+            return seek_outputs, seek_outputs
+
+        ############## rbln validation #############
+        if return_token_timestamps and not self.rbln_token_timestamps:
+            raise RuntimeError(
+                "To use .generate() with return_token_timestamps=True, the model must be compiled with rbln_token_timestamps=True. "
+                "You can compile the model by calling .from_pretrained() with export=True and rbln_token_timestamps=True as keyword arguments, "
+                "or you can generate with return_token_timestamps=False."
+            )
+
+        if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
+            num_frames = getattr(generation_config, "num_frames", None)
+            seek_outputs["token_timestamps"] = self._extract_token_timestamps(
+                seek_outputs, generation_config.alignment_heads, num_frames=num_frames
+            )
+            seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
+
+        seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
+
+        def split_by_batch_index(values, key, batch_idx):
+            if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
+                return [v[batch_idx].cpu() for v in values]
+            if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
+                return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
+            elif key == "past_key_values":
+                # we don't save `past_key_values` in rbln
+                return None
+
+            return values[batch_idx].cpu()
+
+        sequence_tokens = seek_outputs["sequences"]
+
+        ##################################### thkim change #############################################
+        valid_seekoutputs = []
+        for k, v in seek_outputs.items():
+            if v is not None and len(v) > 0 and v[0] is not None:
+                valid_seekoutputs.append((k, v))
+        seek_outputs = [
+            {k: split_by_batch_index(v, k, i) for k, v in valid_seekoutputs}
+            # {k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
+            for i in range(sequence_tokens.shape[0])
+        ]
+
+        return sequence_tokens, seek_outputs
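The validation branch above spells out the contract: `generate(..., return_token_timestamps=True)` only works if the model was compiled with `rbln_token_timestamps=True`. A hedged sketch of that flow; the class name, checkpoint id, and input shape are assumptions, while the `export=True` and `rbln_token_timestamps=True` keywords come from the error message itself:

import torch
from optimum.rbln import RBLNWhisperForConditionalGeneration  # assumed import path

# Compile with token-timestamp support baked into the graph.
model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",        # placeholder checkpoint
    export=True,                  # compile from the PyTorch weights
    rbln_token_timestamps=True,   # required for return_token_timestamps=True at generate time
)

input_features = torch.zeros(1, 80, 3000)  # stand-in for WhisperProcessor log-mel output
outputs = model.generate(input_features, return_token_timestamps=True)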