optimum-rbln: optimum_rbln-0.1.9-py3-none-any.whl → optimum_rbln-0.1.12-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +47 -9
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
- optimum/rbln/diffusers/models/controlnet.py +53 -43
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
- optimum/rbln/modeling_alias.py +6 -11
- optimum/rbln/modeling_base.py +467 -261
- optimum/rbln/modeling_config.py +199 -73
- optimum/rbln/transformers/__init__.py +43 -1
- optimum/rbln/transformers/models/__init__.py +23 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
- optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +50 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +43 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
- optimum_rbln-0.1.12.dist-info/RECORD +103 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/t5/t5_architecture.py
+++ b/optimum/rbln/transformers/models/t5/t5_architecture.py
@@ -43,12 +43,19 @@ if TYPE_CHECKING:
 from transformers import T5ForConditionalGeneration
 
 
+class T5Wrapper:
+    def __init__(self, model):
+        self.encoder = T5EncoderWrapper(model)
+        self.decoder = T5DecoderWrapper(model)
+
+
 class T5Encoder(T5Stack):
     def forward(
         self,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         position_bias: torch.Tensor,
+        batch_ids: torch.Tensor = None,
     ) -> BaseModelOutput:
         hidden_states = self.embed_tokens(input_ids)
         extended_attention_mask = self.invert_attention_mask(attention_mask)
@@ -58,6 +65,7 @@ class T5Encoder(T5Stack):
                 layer_module,
                 hidden_states,
                 position_bias=position_bias,
+                batch_ids=batch_ids,
             )
             hidden_states = layer_outputs[0]
         hidden_states = self.final_layer_norm(hidden_states)
@@ -75,6 +83,7 @@ class T5Decoder(T5Stack):
         position_bias: torch.Tensor,
         encoder_decoder_position_bias: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_ids: torch.Tensor,
     ) -> BaseModelOutputWithPastAndCrossAttentions:
         hidden_states = self.embed_tokens(input_ids)
         extended_attention_mask = self.invert_attention_mask(attention_mask)
@@ -84,6 +93,7 @@ class T5Decoder(T5Stack):
         encoder_decoder_position_bias = encoder_decoder_position_bias + encoder_extended_attention_mask
 
         present_key_value_states = ()
+
         for layer_module, past_key_value in zip(self.block, past_key_values):
             layer_outputs = _T5Block.forward(
                 layer_module,
@@ -93,6 +103,7 @@ class T5Decoder(T5Stack):
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
                 past_key_value=past_key_value,
                 cache_position=cache_position,
+                batch_ids=batch_ids,
             )
             hidden_states, present_key_value_state = layer_outputs[:2]
             present_key_value_states = present_key_value_states + (present_key_value_state,)
@@ -117,17 +128,26 @@ class T5EncoderWrapper(torch.nn.Module):
         )
         self.encoder_max_length = None
         self.decoder_max_length = None
-        self.decoder_batch_size = 1
 
-    def forward(
-
-
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cross_key_value: torch.Tensor = None,
+        batch_idx: torch.Tensor = None,
+    ) -> torch.Tensor:
         decoder_max_length = self.decoder_max_length or self.default_max_length
         encoder_max_length = self.encoder_max_length or self.default_max_length
 
         attn_layer = self.encoder.block[0].layer[0].SelfAttention
         encoder_position_bias = T5Attention.compute_bias(attn_layer, encoder_max_length, encoder_max_length)
-        encoder_outputs = T5Encoder.forward(
+        encoder_outputs = T5Encoder.forward(
+            self.encoder,
+            input_ids,
+            attention_mask,
+            encoder_position_bias,
+            batch_ids=torch.tensor(0, dtype=torch.int32),
+        )
 
         attn_layer = self.decoder.block[0].layer[0].SelfAttention
         decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
@@ -138,22 +158,14 @@ class T5EncoderWrapper(torch.nn.Module):
 
         dummy_past_key_value = []
         for i in range(self.config.num_layers):
-            pkv_self_attn_key = torch.zeros(
-                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
-            )
-            pkv_self_attn_value = torch.zeros(
-                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
-            )
-            pkv_cross_attn_key = torch.zeros(
-                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
-            )
-            pkv_cross_attn_value = torch.zeros(
-                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
-            )
+            pkv_self_attn_key = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
+            pkv_self_attn_value = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
+            pkv_cross_attn_key = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
+            pkv_cross_attn_value = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
             dummy_past_key_value.append(layer_pkv)
 
-        decoder_attention_mask = torch.zeros(
+        decoder_attention_mask = torch.zeros(1, decoder_max_length, dtype=torch.float32)
         decoder_attention_mask[:, :1] = 1
 
         # Since first step of decoder has different graph to further step of it,
@@ -161,7 +173,7 @@ class T5EncoderWrapper(torch.nn.Module):
         # TODO(jongho): Separate first-step-decoder.
         decoder_outputs = T5Decoder.forward(
             self.decoder,
-            input_ids=torch.zeros(
+            input_ids=torch.zeros(1, 1, dtype=torch.int64),
             attention_mask=decoder_attention_mask,
             position_bias=decoder_position_bias,
             encoder_decoder_position_bias=encoder_decoder_position_bias,
@@ -169,6 +181,7 @@ class T5EncoderWrapper(torch.nn.Module):
             encoder_attention_mask=attention_mask,
             past_key_values=dummy_past_key_value,
             cache_position=torch.tensor(0, dtype=torch.int32),
+            batch_ids=torch.tensor(0, dtype=torch.int32),
         )
 
         past_key_values = decoder_outputs.past_key_values
@@ -179,7 +192,9 @@ class T5EncoderWrapper(torch.nn.Module):
             cross_kv_cache.append(past_key_values[i][3])
         cross_kv_cache = torch.stack(cross_kv_cache, dim=0)
 
-
+        cross_key_value = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx + 1)
+
+        return cross_key_value
 
 
 class T5DecoderWrapper(torch.nn.Module):
@@ -201,6 +216,7 @@ class T5DecoderWrapper(torch.nn.Module):
         attention_mask: torch.Tensor,
         encoder_attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_position: torch.Tensor,
         self_kv_cache: torch.Tensor,
         cross_kv_cache: torch.Tensor,
     ) -> Tuple[torch.Tensor]:
@@ -210,6 +226,11 @@ class T5DecoderWrapper(torch.nn.Module):
         encoder_max_length = self.encoder_max_length or self.default_max_length
         decoder_max_length = self.decoder_max_length or self.default_max_length
 
+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
+
         kv_cache = ()
         for i in range(0, num_layers * 2, 2):
             kv_cache = kv_cache + (
@@ -223,7 +244,13 @@ class T5DecoderWrapper(torch.nn.Module):
 
         attn_layer = self.model.decoder.block[0].layer[0].SelfAttention
         _decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
-
+
+        # position_bias need to compute with batch (for cb)
+        batch_decoder_position_bias = []
+        for i in range(input_ids.shape[0]):
+            batch_position_bias = _decoder_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
+            batch_decoder_position_bias.append(batch_position_bias)
+        decoder_position_bias = torch.cat(batch_decoder_position_bias, dim=0)
 
         attn_layer = self.model.decoder.block[0].layer[1].EncDecAttention
         encoder_decoder_position_bias = torch.zeros(1, attn_layer.n_heads, 1, encoder_max_length)
@@ -238,6 +265,7 @@ class T5DecoderWrapper(torch.nn.Module):
             encoder_decoder_position_bias=encoder_decoder_position_bias,
             past_key_values=kv_cache,
             cache_position=cache_position,
+            batch_ids=rbln_batch_position,
         )
 
         past_key_values = decoder_outputs.past_key_values
@@ -255,7 +283,7 @@ class T5DecoderWrapper(torch.nn.Module):
 
         self_kv_cache = torch.stack(self_kv_cache, dim=0)
 
-        return lm_logits, self_kv_cache
+        return lm_logits, self_kv_cache, batch_position
 
 
 class _T5Attention(T5Attention):
@@ -269,10 +297,10 @@ class _T5Attention(T5Attention):
         position_bias: torch.Tensor = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         cache_position: Optional[torch.Tensor] = None,  # current cache sequence length
+        batch_index: torch.Tensor = None,
        is_self_attn: Optional[bool] = None,
     ) -> Tuple[torch.Tensor]:
         batch_size = hidden_states.shape[0]
-        cross_batch_size = key_value_states.shape[0] if not is_self_attn and cache_position == 0 else None
 
         def shape(states, batch_size):
             """projection"""
@@ -288,39 +316,80 @@ class _T5Attention(T5Attention):
         if is_self_attn:
             key_states = shape(self.k(hidden_states), batch_size)
             value_states = shape(self.v(hidden_states), batch_size)
-            if past_key_value is not None:
-                # decoder self attn
-                cache_k = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                cache_v = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                past_key_value = (cache_k, cache_v)
-                key_states, value_states = past_key_value
-
         else:
             # cross-attn
-            if cache_position == 0:
-                key_states = shape(self.k(key_value_states),
-                value_states = shape(self.v(key_value_states),
+            if cache_position.dim() == 0:
+                key_states = shape(self.k(key_value_states), key_value_states.shape[0])
+                value_states = shape(self.v(key_value_states), key_value_states.shape[0])
                 past_key_value = key_states, value_states
             else:
                 key_states = past_key_value[0]
                 value_states = past_key_value[1]
 
-
-
-
+        if (batch_index is None or batch_index == -1) and batch_size > 1:
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+
+            for b in range(batch_size):
+                batch_query_states = query_states[b].unsqueeze(0)
+                batch_key_states = key_states[b].unsqueeze(0)
+                batch_value_states = value_states[b].unsqueeze(0)
+
+                if is_self_attn and past_key_value is not None:
+                    batch_key_states = (
+                        past_key_value[0][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                    batch_value_states = (
+                        past_key_value[1][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+
+                scores = torch.matmul(batch_query_states, batch_key_states.transpose(3, 2))
+                scores += position_bias[b]
+                attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+                attn_output = unshape(torch.matmul(attn_weights, batch_value_states), 1)
+                all_key_states.append(batch_key_states)
+                all_value_states.append(batch_value_states)
+                all_attn_output.append(attn_output)
+
+            key_states = torch.cat(all_key_states, dim=0)
+            value_states = torch.cat(all_value_states, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
 
-
-
-
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
 
-
-
+            if is_self_attn and past_key_value is not None:
+                key_states = past_key_value[0].slice_scatter(
+                    key_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                value_states = past_key_value[1].slice_scatter(
+                    value_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+            # compute scores
+            scores = torch.matmul(query_states, key_states.transpose(3, 2))
+            scores += position_bias
+
+            attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+                scores
+            )  # (batch_size, n_heads, seq_length, key_length)
 
-
+            attn_output = unshape(
+                torch.matmul(attn_weights, value_states), batch_size
+            )  # (batch_size, seq_length, dim)
+
+        attn_output = self.o(attn_output)
+        present_key_value = (key_states, value_states)
+        outputs = (attn_output,) + (present_key_value,)
         return outputs
 
 
@@ -331,6 +400,7 @@ class _T5LayerSelfAttention(T5LayerSelfAttention):
         position_bias: torch.Tensor = None,
         past_key_value: Tuple[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        batch_index: torch.Tensor = None,
     ):
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = _T5Attention.forward(
@@ -339,6 +409,7 @@ class _T5LayerSelfAttention(T5LayerSelfAttention):
             position_bias=position_bias,
             past_key_value=past_key_value,
             cache_position=cache_position,
+            batch_index=batch_index,
             is_self_attn=True,
         )
 
@@ -356,6 +427,7 @@ class _T5LayerCrossAttention(T5LayerCrossAttention):
         position_bias: torch.Tensor = None,
         past_key_value: Tuple[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        batch_index: torch.Tensor = None,
     ):
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = _T5Attention.forward(
@@ -365,6 +437,7 @@ class _T5LayerCrossAttention(T5LayerCrossAttention):
             position_bias=position_bias,
             past_key_value=past_key_value,
             cache_position=cache_position,
+            batch_index=batch_index,
             is_self_attn=False,
         )
 
@@ -383,6 +456,7 @@ class _T5Block(T5Block):
         encoder_decoder_position_bias=None,
         past_key_value=None,
         cache_position=None,
+        batch_ids=None,
     ):
         if past_key_value is not None:
             if not self.is_decoder:
@@ -403,13 +477,13 @@ class _T5Block(T5Block):
             cross_attn_past_key_value = past_key_value[2:]
         else:
             self_attn_past_key_value, cross_attn_past_key_value = None, None
-
         self_attention_outputs = _T5LayerSelfAttention.forward(
             self.layer[0],
             hidden_states=hidden_states,
             position_bias=position_bias,
             past_key_value=self_attn_past_key_value,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )
 
         hidden_states, present_key_value_state = self_attention_outputs[:2]
@@ -423,6 +497,7 @@ class _T5Block(T5Block):
                 position_bias=encoder_decoder_position_bias,
                 past_key_value=cross_attn_past_key_value,
                 cache_position=cache_position,
+                batch_index=batch_ids,
             )
             hidden_states = cross_attention_outputs[0]
             # Combine self attn and cross attn key value states
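The thread running through these t5_architecture.py hunks is per-sequence KV-cache handling: a `batch_ids`/`batch_idx`/`batch_position` argument is plumbed from the new `T5Wrapper` down to `_T5Attention`, and cache updates switch to `slice_scatter` on the row selected by that index, which is what lets the encoder wrapper write its cross-attention cache into a single batch slot. A minimal sketch of that update pattern follows; the tensor names and shapes are hypothetical, not optimum-rbln API.

```python
import torch

# Preallocated cache: (batch, heads, max_len, head_dim); values are illustrative.
cache = torch.zeros(4, 8, 128, 64)
new_kv = torch.rand(1, 8, 1, 64)  # key/value for one new token of one sequence
batch_idx, step = 2, 17           # which sequence slot and which position to fill

# Update one sequence's slot without touching the others, mirroring how the diff
# uses Tensor.slice_scatter along the sequence dim and then the batch dim.
row = cache[batch_idx].unsqueeze(0)                                # (1, 8, 128, 64)
row = row.slice_scatter(new_kv, dim=-2, start=step, end=step + 1)  # write position `step`
cache = cache.slice_scatter(row, dim=0, start=batch_idx, end=batch_idx + 1)
```

`slice_scatter` returns a new tensor instead of mutating in place, which keeps the graph traceable for compilation; that is the design choice behind replacing the earlier indexed assignments.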
--- a/optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -22,14 +22,14 @@
 # from Rebellions Inc.
 
 import logging
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Any, Dict, Union
 
 import torch
 from transformers import AutoModelForMaskedLM, PretrainedConfig, Wav2Vec2ForCTC
 from transformers.modeling_outputs import CausalLMOutput
 
 from ....modeling_base import RBLNModel
-from ....modeling_config import
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
 
 
 logger = logging.getLogger(__name__)
@@ -65,7 +65,6 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    model_type = "rbln_model"
     main_input_name = "input_values"
     auto_model_class = AutoModelForMaskedLM
 
@@ -78,10 +77,10 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
 
         if rbln_max_seq_len is None:
             for tokenizer in preprocessors:
@@ -91,8 +90,6 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified!")
 
-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
@@ -107,11 +104,19 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
             ),
         ]
 
-
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
 
-        rbln_config
-
-
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+            }
+        )
 
         return rbln_config
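These modeling_wav2vec2.py hunks show the 0.1.12 configuration refactor: `_get_rbln_config` now receives a single `rbln_kwargs` dict instead of separate `rbln_*` parameters, builds an `RBLNCompileConfig` from `input_info`, and records user options in `rbln_config.model_cfg`. A hedged sketch of how this surface is typically driven from user code; the checkpoint id and values are illustrative, and only the `export=True` plus `rbln_*` keyword pattern is implied by this diff and the Whisper error message below.

```python
from optimum.rbln import RBLNWav2Vec2ForCTC

# Compile a Wav2Vec2 CTC model for RBLN NPUs. The rbln_* keywords are collected
# into rbln_kwargs and land in rbln_config.model_cfg ("batch_size", "max_seq_len").
model = RBLNWav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",  # illustrative checkpoint
    export=True,                    # trace and compile instead of loading a prebuilt artifact
    rbln_batch_size=1,
    rbln_max_seq_len=160000,        # illustrative: raw audio samples per input
)
```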
--- /dev/null
+++ b/optimum/rbln/transformers/models/whisper/generation_whisper.py
@@ -0,0 +1,68 @@
+import torch
+from transformers import GenerationMixin
+from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
+
+
+class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
+    """
+    This class is based on transformers version 4.44.2.
+    It uses the same generate() method, so it's crucial to maintain the inheritance order.
+    Ensure WhisperGenerationMixin is listed before GenerationMixin.
+    """
+
+    def _postprocess_outputs(
+        self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, *args, **kwargs
+    ):
+        # remove all previously passed decoder input ids
+
+        ################################## rbln_change for 4.40.2###################################
+        # 4.40.2 has no keyword shortform, it has seperate codes from generation_fallback
+        is_shortform = kwargs.get("is_shortform", False)
+        start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
+
+        if isinstance(seek_outputs, torch.Tensor):
+            seek_outputs = seek_outputs[:, start_idx:]
+            return seek_outputs, seek_outputs
+
+        ############## rbln validation#############
+        if return_token_timestamps and not self.rbln_token_timestamps:
+            raise RuntimeError(
+                "To use .generate() with return_token_timestamps=True, the model must be compiled with rbln_token_timestamps=True. "
+                "You can compile the model by calling .from_pretrained() with export=True and rbln_token_timestamps=True as keyword arguments, "
+                "or you can generate with return_token_timestamps=False."
+            )
+
+        if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
+            num_frames = getattr(generation_config, "num_frames", None)
+            seek_outputs["token_timestamps"] = self._extract_token_timestamps(
+                seek_outputs, generation_config.alignment_heads, num_frames=num_frames
+            )
+            seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
+
+        seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
+
+        def split_by_batch_index(values, key, batch_idx):
+            if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
+                return [v[batch_idx].cpu() for v in values]
+            if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
+                return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
+            elif key == "past_key_values":
+                # we don't save `past_key_values in rbln
+                return None
+
+            return values[batch_idx].cpu()
+
+        sequence_tokens = seek_outputs["sequences"]
+
+        ##################################### thkim change #############################################
+        valid_seekoutputs = []
+        for k, v in seek_outputs.items():
+            if v is not None and len(v) > 0 and v[0] is not None:
+                valid_seekoutputs.append((k, v))
+        seek_outputs = [
+            {k: split_by_batch_index(v, k, i) for k, v in valid_seekoutputs}
+            # {k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
+            for i in range(sequence_tokens.shape[0])
+        ]
+
+        return sequence_tokens, seek_outputs