optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +2 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/ops/__init__.py +2 -1
- optimum/rbln/ops/attn.py +9 -7
- optimum/rbln/ops/linear.py +25 -0
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +4 -3
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +20 -17
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +14 -14
- optimum/rbln/transformers/models/t5/modeling_t5.py +3 -210
- optimum/rbln/transformers/models/t5/t5_architecture.py +9 -3
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +24 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +422 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +341 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +98 -47
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +71 -26
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4a1.dist-info}/METADATA +5 -5
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4a1.dist-info}/RECORD +21 -17
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4a1.dist-info}/licenses/LICENSE +0 -0
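The most notable addition in 0.7.4a1 is the new time_series_transformers model package (architecture, modeling, and __init__ files listed above). A minimal usage sketch follows; the exported class name RBLNTimeSeriesTransformerForPrediction and the rbln_* compile-time keyword arguments are assumptions inferred from the file layout and the package's existing conventions, not confirmed by this diff:

# Hypothetical sketch -- class name and rbln_* kwargs are assumptions, and an
# RBLN NPU plus the rebel SDK are required for compilation to actually run.
from optimum.rbln import RBLNTimeSeriesTransformerForPrediction

model = RBLNTimeSeriesTransformerForPrediction.from_pretrained(
    "huggingface/time-series-transformer-tourism-monthly",
    export=True,        # compile the Hugging Face checkpoint for the NPU
    rbln_batch_size=1,  # assumed compile-time option
)
model.save_pretrained("tst-rbln")  # reload later without recompiling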
optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py
ADDED
@@ -0,0 +1,341 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    Seq2SeqLMOutput,
+)
+from transformers.utils import logging
+
+from ....ops import register_rbln_custom_cache_update, register_rbln_custom_paged_add_softmax_attention
+
+
+logger = logging.get_logger(__name__)
+
+
+class TimeSeriesTransformersWrapper:
+    def __init__(self, model, num_parallel_samples):
+        register_rbln_custom_cache_update()
+        register_rbln_custom_paged_add_softmax_attention()
+        self.encoder = TimeSeriesTransformersEncoderWrapper(model)
+        self.decoder = TimeSeriesTransformersDecoderWrapper(model, num_parallel_samples)
+
+
+class TimeSeriesTransformersEncoderWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.config = model.config
+        self.encoder = model.get_encoder()
+        self.num_heads = self.config.decoder_attention_heads
+        self.d_kv = self.config.d_model // self.num_heads
+        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)
+
+    def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
+        return (
+            nn.ModuleList(layer.encoder_attn.k_proj for layer in decoder_layers),
+            nn.ModuleList(layer.encoder_attn.v_proj for layer in decoder_layers),
+        )
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        cross_key_values: torch.Tensor,  # n_layers, batch_size, num_heads, context_length, d_kv
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+        # 1. get encoder last_hidden_states
+        encoder_outputs = self.encoder(inputs_embeds=inputs_embeds, attention_mask=None, return_dict=False)
+        last_hidden_states = encoder_outputs[0]
+
+        # 2. pre-compute cross_attention's past_key_value which used in decoder phase.
+        cross_kv = []
+        batch_size = inputs_embeds.shape[0]
+        for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
+            past_k = k_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
+            past_v = v_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
+
+            cross_kv.append(past_k)
+            cross_kv.append(past_v)
+
+        cross_kv = torch.stack(cross_kv, dim=0)
+
+        # 3. update cross_attention's past_key_value to the device-dram for optimization.
+        bidx = torch.tensor(0, dtype=torch.int16)
+        axis = torch.tensor(1, dtype=torch.int16)
+        enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
+
+        return enc_output
+
+
+class TimeSeriesTransformersDecoderWrapper(torch.nn.Module):
+    def __init__(self, model, num_parallel_samples):
+        super().__init__()
+        self.config = model.config
+        self.num_layers = self.config.decoder_layers
+        self.decoder = self.convert_to_rbln_tst_decoder(model, num_parallel_samples)
+        self.parameter_projection = model.parameter_projection
+
+    def convert_to_rbln_tst_decoder(self, model: nn.Module, num_parallel_samples: int):
+        new_layers = []
+        for layer in model.get_decoder().layers:
+            self_attn = TimeSeriesTransformersSelfAttention(layer.self_attn, num_parallel_samples)
+            cross_attn = TimeSeriesTransformersCrossAttention(layer.encoder_attn, num_parallel_samples)
+            new_layers.append(TimeSeriesTransformersDecoderLayer(layer, self_attn, cross_attn))
+
+        decoder_model = TimeSeriesTransformersDecoder(model.get_decoder(), new_layers)
+
+        return decoder_model
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        decoder_attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        block_tables: torch.Tensor,
+        cross_kv_cache: torch.Tensor,  # batch_size, num_heads, context_length, d_kv
+        *self_kv_cache: torch.Tensor,  # batch_size * num_parallel_samples, num_heads, prediction_length, d_kv
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        # prepare past_key_values
+        self_past_key_values = ()
+        cross_past_key_values = ()
+        for i in range(0, self.num_layers * 2, 2):
+            self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+            cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+
+        # Decode
+        last_hidden_states = self.decoder(
+            inputs_embeds=inputs_embeds,
+            attention_mask=decoder_attention_mask,
+            cache_position=cache_position,
+            block_tables=block_tables,
+            self_past_key_values=self_past_key_values,
+            cross_past_key_values=cross_past_key_values,
+        )
+
+        params = self.parameter_projection(last_hidden_states[:, -1:])
+
+        outputs = ()
+        outputs += (params,)
+        outputs += (last_hidden_states,)
+
+        return outputs
+
+
+class TimeSeriesTransformersDecoder(nn.Module):
+    def __init__(self, model, layers, **kwargs):
+        super().__init__()
+        self._original_mod = model
+        self.config = model.config
+        self.layers = nn.ModuleList(layers)
+        self.value_embedding = model.value_embedding
+        self.embed_positions = model.embed_positions
+        self.layernorm_embedding = model.layernorm_embedding
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        self_past_key_values: Optional[torch.Tensor] = None,
+        cross_past_key_values: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        block_tables: torch.Tensor = None,
+    ):
+        input_shape = inputs_embeds.size()[:-1]
+
+        # prepare casual_attn_mask
+        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
+
+        hidden_states = self.value_embedding(inputs_embeds)
+        embed_pos = self.embed_positions.weight[cache_position + self.config.context_length]
+        hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
+
+        # iterate decoder_layer
+        for self_past_key_value, cross_past_key_value, decoder_layer in zip(
+            self_past_key_values, cross_past_key_values, self.layers
+        ):
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                # encoder_attention_mask=encoder_attention_mask,
+                self_past_key_value=self_past_key_value,
+                cross_past_key_value=cross_past_key_value,
+                cache_position=cache_position,
+                block_tables=block_tables,
+            )
+
+        return hidden_states
+
+
+class TimeSeriesTransformersDecoderLayer(nn.Module):
+    def __init__(self, decoder_layer, self_attn, cross_attn):
+        super().__init__()
+        self._original_mod = decoder_layer
+        self.self_attn = self_attn
+        self.encoder_attn = cross_attn
+        self.embed_dim = decoder_layer.embed_dim
+        self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
+        self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
+        self.final_layer_norm = decoder_layer.final_layer_norm
+        self.activation_fn = decoder_layer.activation_fn
+        self.fc1 = decoder_layer.fc1
+        self.fc2 = decoder_layer.fc2
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        block_tables: torch.Tensor = None,
+    ) -> torch.Tensor:
+        # Self Attention Block
+        residual = hidden_states
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_past_key_value,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            block_tables=block_tables,
+        )
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Cross-Attention Block
+        residual = hidden_states
+        hidden_states = self.encoder_attn(
+            hidden_states=hidden_states,
+            past_key_value=cross_past_key_value,
+            # attention_mask=encoder_attention_mask,
+        )
+        hidden_states = residual + hidden_states
+        hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+        # Fully Connected Block
+        residual = hidden_states
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        return hidden_states
+
+
+class TimeSeriesTransformersAttention(nn.Module):
+    def __init__(self, attn, num_parallel_samples):
+        super().__init__()
+        self._original_mod = attn
+        self.q_proj = attn.q_proj
+        self.k_proj = attn.k_proj
+        self.v_proj = attn.v_proj
+        self.out_proj = attn.out_proj
+        self.num_heads = attn.num_heads
+        self.embed_dim = attn.embed_dim
+        self.head_dim = attn.head_dim
+        self.scaling = attn.scaling
+        self.num_parallel_samples = num_parallel_samples
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+
+class TimeSeriesTransformersSelfAttention(TimeSeriesTransformersAttention):
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(1, seq_len, 1, bsz * self.num_heads, self.head_dim).transpose(1, 3)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, tgt_len, _ = hidden_states.size()
+        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
+        query_states = query_states * self.scaling
+
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        block_size = past_key_value[0].shape[-2]
+        attn_output = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask.unsqueeze(2),
+            past_key_value[0].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
+            past_key_value[1].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
+            cache_position.expand(bsz, 1),
+            torch.tensor(1.0, dtype=torch.float32),  # scale
+            block_tables,
+            block_size,
+        )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
+
+
+class TimeSeriesTransformersCrossAttention(TimeSeriesTransformersSelfAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        batch_size, query_len, _ = hidden_states.size()
+        query_states = (
+            self.q_proj(hidden_states)
+            .view(
+                batch_size // self.num_parallel_samples,
+                self.num_parallel_samples,
+                query_len,
+                self.num_heads,
+                self.head_dim,
+            )
+            .transpose(2, 3)
+        )
+        query_states = query_states * self.scaling
+
+        key_states = past_key_value[0].unsqueeze(1)
+        value_states = past_key_value[1].unsqueeze(1)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4))
+        attn_weights = attn_weights
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.view(batch_size, self.num_heads, query_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
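For orientation, the wrapper above is what the compilation path traces: it splits a Hugging Face TimeSeriesTransformerForPrediction into an encoder graph that precomputes and caches cross-attention K/V and a decoder graph that consumes those caches through the paged add-softmax attention op. A rough sketch of how it might be instantiated (compilation details omitted; this assumes the rebel SDK is installed so the custom-op registration in the constructor succeeds):

# Sketch only: instantiating the wrapper needs the rebel SDK, since the
# constructor registers the rbln_custom_ops used inside the traced graphs.
from transformers import TimeSeriesTransformerForPrediction

hf_model = TimeSeriesTransformerForPrediction.from_pretrained(
    "huggingface/time-series-transformer-tourism-monthly"
).eval()

# num_parallel_samples matches the HF config field of the same name and sets how
# many sampling trajectories the decoder processes per input series.
wrapper = TimeSeriesTransformersWrapper(hf_model, hf_model.config.num_parallel_samples)

encoder_graph = wrapper.encoder  # writes cross-attention K/V into the device cache
decoder_graph = wrapper.decoder  # single-step decode over paged self-attention KV caches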
@@ -38,29 +38,34 @@ from .whisper_architecture import WhisperWrapper
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, PretrainedConfig, PreTrainedModel
+    from transformers import AutoFeatureExtractor, AutoProcessor, GenerationConfig, PretrainedConfig, PreTrainedModel
 
 
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
-    def forward(self,
-
-
-
-        n_pad_to_batch = self.batch_size - input_features.shape[0]
-        if n_pad_to_batch > 0:
-            input_features = torch.nn.functional.pad(input_features, (0, 0, 0, 0, 0, n_pad_to_batch))
-
-        _ = super().forward(input_features=input_features)
-
-        # dummy output for generation
-        return BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
+    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+        output = super().forward(*args, **kwargs)
+        return BaseModelOutput(last_hidden_state=output)
 
 
 class RBLNRuntimeDecoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
+    def __init__(
+        self,
+        runtime: rebel.Runtime,
+        batch_size: int,
+        dec_max_seq_len: int,
+        use_attention_mask: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(runtime, **kwargs)
+        self.batch_size = batch_size
+        self.dec_max_seq_len = dec_max_seq_len
+        self.use_attention_mask = use_attention_mask
+        self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
+
     def forward(
         self,
         decoder_input_ids: torch.Tensor = None,
@@ -69,13 +74,24 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
     ):
         inputs_bsz = decoder_input_ids.shape[0]
         padded_bsz = self.batch_size - inputs_bsz
+
         if padded_bsz > 0:
             decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, 0, 0, padded_bsz))
 
+        if self.use_attention_mask:
+            for b_idx in range(self.batch_size):
+                decoding_step = cache_position[b_idx].item()
+                if not (0 <= decoding_step < self.dec_max_seq_len):
+                    raise ValueError(
+                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    )
+                decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
         outputs = super().forward(
-            decoder_input_ids
-            decoder_attention_mask
-            cache_position
+            decoder_input_ids,
+            decoder_attention_mask if self.use_attention_mask else None,
+            cache_position,
+            block_tables=self.default_block_tables,
         )
 
         if isinstance(outputs, torch.Tensor):
@@ -104,12 +120,15 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
         self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
         self.rbln_token_timestamps = self.rbln_config.model_cfg["token_timestamps"]
+        self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
 
-        self.encoder = RBLNRuntimeEncoder(
-            runtime=self.model[0], main_input_name="input_features", batch_size=self.batch_size
-        )
+        self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
         self.decoder = RBLNRuntimeDecoder(
-            runtime=self.model[1],
+            runtime=self.model[1],
+            main_input_name="input_ids",
+            batch_size=self.batch_size,
+            dec_max_seq_len=self.dec_max_seq_len,
+            use_attention_mask=self.use_attention_mask,
         )
 
         # skip encoder & first decoder when language detected
@@ -121,6 +140,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         # input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
         self.model = WhisperModel(self.config)
         self.pad_token_id = self.config.pad_token_id
+        self.generation_config.forced_decoder_ids = None
 
     def can_generate(self):
         return True
@@ -152,7 +172,10 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
-
+        use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
+        return WhisperWrapper(
+            model, use_attention_mask=use_attention_mask, rbln_token_timestamps=rbln_token_timestamps
+        )
 
     @classmethod
     @torch.inference_mode()
@@ -215,28 +238,34 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         if rbln_dec_max_seq_len is None:
            rbln_dec_max_seq_len = model_config.max_length
 
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # use_attention_mask conditions
+        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+        if rbln_use_attention_mask is None:
+            rbln_use_attention_mask = False
+            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+            if rbln_npu == "RBLN-CA02":
+                rbln_use_attention_mask = True
+
+        enc_input_info = [
+            ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
+            ("block_tables", [1], "int16"),
+            (
+                "cross_key_value_states",
+                [
+                    model_config.decoder_layers * 2,
+                    rbln_batch_size,
+                    model_config.decoder_attention_heads,
+                    enc_max_seq_len,
+                    model_config.d_model // model_config.decoder_attention_heads,
+                ],
+                "float32",
+            ),
+        ]
 
         dec_input_info = [
             ("decoder_input_ids", [rbln_batch_size, 1], "int64"),
-            ("
-            ("
+            ("cache_position", [rbln_batch_size, 1], "int32"),
+            ("block_tables", [rbln_batch_size, 1], "int16"),
         ]
         dec_input_info.extend(
             [
@@ -269,6 +298,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
             ]
         )
 
+        if rbln_use_attention_mask:
+            dec_input_info.insert(1, ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+
         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
@@ -283,6 +315,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
                 "batch_size": rbln_batch_size,
                 "dec_max_seq_len": rbln_dec_max_seq_len,
                 "token_timestamps": rbln_token_timestamps,
+                "use_attention_mask": rbln_use_attention_mask,
             }
         )
 
@@ -327,11 +360,25 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
 
     # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L512
     def _prepare_encoder_decoder_kwargs_for_generation(
-        self,
+        self,
+        inputs_tensor: torch.Tensor,
+        model_kwargs,
+        model_input_name: Optional[str] = None,
+        generation_config: Optional["GenerationConfig"] = None,
+        **kwargs,
    ) -> Dict[str, Any]:
+        batch_size = inputs_tensor.shape[0]
+        n_pad_to_batch = self.batch_size - batch_size
+        if n_pad_to_batch > 0:
+            inputs_tensor = torch.nn.functional.pad(inputs_tensor, (0, 0, 0, 0, 0, n_pad_to_batch))
+
         if not self.is_language_detected:
-
-
+            for b in range(inputs_tensor.shape[0]):
+                block_tables = torch.tensor([b], dtype=torch.int16)
+                model_kwargs["encoder_outputs"] = self.encoder(
+                    input_features=inputs_tensor[b].unsqueeze(0), block_tables=block_tables
+                )
+            self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
         else:
             model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
 
@@ -359,7 +406,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
             decoder_output = self.decoder(
                 decoder_input_ids=input_ids[:, step : step + 1].contiguous(),
                 decoder_attention_mask=self.decoder_attention_mask,
-                cache_position=
+                cache_position=torch.full((self.batch_size, 1), step, dtype=torch.int32),
             )
             cross_attentions.append(decoder_output.cross_attentions)
             lm_logits = decoder_output.logits
@@ -374,15 +421,19 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         # detect language pass
         # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/whisper/generation_whisper.py#L1442
         else:
+            # for language auto detection (generate with language=None)
             if encoder_outputs is None:
-
-
+                for b in range(input_features.shape[0]):
+                    block_tables = torch.tensor([b], dtype=torch.int16)
+                    self.encoder(input_features=input_features[b].unsqueeze(0), block_tables=block_tables)
+
+                self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
             self.is_language_detected = True
             self.decoder_attention_mask[:, 0] = 1
             decoder_output = self.decoder(
                 decoder_input_ids=decoder_input_ids.contiguous(),
                 decoder_attention_mask=self.decoder_attention_mask,
-                cache_position=torch.zeros([], dtype=torch.int32),
+                cache_position=torch.zeros([self.rbln_config.model_cfg["batch_size"], 1], dtype=torch.int32),
             )
             lm_logits = decoder_output.logits
             self.language_cross = decoder_output.cross_attentions