optimum-rbln 0.7.4a0__py3-none-any.whl → 0.7.4a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +2 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/linear.py +25 -0
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +2 -3
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +20 -17
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +2 -2
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +24 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +422 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +341 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +86 -47
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +62 -27
- {optimum_rbln-0.7.4a0.dist-info → optimum_rbln-0.7.4a1.dist-info}/METADATA +5 -5
- {optimum_rbln-0.7.4a0.dist-info → optimum_rbln-0.7.4a1.dist-info}/RECORD +18 -14
- {optimum_rbln-0.7.4a0.dist-info → optimum_rbln-0.7.4a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a0.dist-info → optimum_rbln-0.7.4a1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py
ADDED
@@ -0,0 +1,341 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from typing import Optional, Tuple, Union
|
25
|
+
|
26
|
+
import torch
|
27
|
+
from torch import nn
|
28
|
+
from transformers.modeling_attn_mask_utils import (
|
29
|
+
_prepare_4d_causal_attention_mask,
|
30
|
+
)
|
31
|
+
from transformers.modeling_outputs import (
|
32
|
+
BaseModelOutput,
|
33
|
+
Seq2SeqLMOutput,
|
34
|
+
)
|
35
|
+
from transformers.utils import logging
|
36
|
+
|
37
|
+
from ....ops import register_rbln_custom_cache_update, register_rbln_custom_paged_add_softmax_attention
|
38
|
+
|
39
|
+
|
40
|
+
logger = logging.get_logger(__name__)
|
41
|
+
|
42
|
+
|
43
|
+
class TimeSeriesTransformersWrapper:
    """Bundle the RBLN-compilable encoder and decoder wrappers of a time-series transformer.

    Constructing this object registers the RBLN custom ops used by the wrapped forward
    passes (``rbln_cache_update`` in the encoder wrapper and
    ``paged_add_softmax_attn_decode`` in the decoder self-attention).

    Args:
        model: The source time-series transformer model to wrap.
        num_parallel_samples: Forwarded to the decoder wrapper; the decoder batch is
            expected to be ``batch_size * num_parallel_samples`` rows.
    """

    def __init__(self, model, num_parallel_samples):
        # Register the custom ops up front so torch.ops.rbln_custom_ops.* resolve at trace time.
        for register_op in (
            register_rbln_custom_cache_update,
            register_rbln_custom_paged_add_softmax_attention,
        ):
            register_op()

        encoder_wrapper = TimeSeriesTransformersEncoderWrapper(model)
        decoder_wrapper = TimeSeriesTransformersDecoderWrapper(model, num_parallel_samples)
        self.encoder = encoder_wrapper
        self.decoder = decoder_wrapper
|
49
|
+
|
50
|
+
|
51
|
+
class TimeSeriesTransformersEncoderWrapper(torch.nn.Module):
    """Encoder wrapper that also pre-computes the decoder's cross-attention key/values.

    The forward pass runs the original encoder, projects its last hidden state through
    every decoder layer's cross-attention ``k_proj`` / ``v_proj``, and writes the stacked
    result into ``cross_key_values`` with the ``rbln_cache_update`` custom op so the
    decoder can reuse it on every step.
    """

    def __init__(self, model):
        super().__init__()
        self.config = model.config
        self.encoder = model.get_encoder()
        # The projections belong to the decoder's cross-attention, so use the decoder head count.
        self.num_heads = self.config.decoder_attention_heads
        self.d_kv = self.config.d_model // self.num_heads
        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)

    def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
        """Return ``(k_projs, v_projs)`` ModuleLists taken from each decoder layer's ``encoder_attn``."""
        return (
            nn.ModuleList(layer.encoder_attn.k_proj for layer in decoder_layers),
            nn.ModuleList(layer.encoder_attn.v_proj for layer in decoder_layers),
        )

    def _compute_cross_kv(self, last_hidden_states: torch.Tensor) -> torch.Tensor:
        """Project encoder states into stacked per-layer cross-attention key/values.

        Args:
            last_hidden_states: Encoder output of shape ``(batch_size, seq_len, d_model)``.

        Returns:
            Tensor of shape ``(2 * n_layers, batch_size, num_heads, seq_len, d_kv)`` ordered
            ``[k_0, v_0, k_1, v_1, ...]``.
        """
        batch_size = last_hidden_states.shape[0]

        def split_heads(states: torch.Tensor) -> torch.Tensor:
            # (batch, seq, d_model) -> (batch, num_heads, seq, d_kv)
            return states.view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)

        cross_kv = []
        for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
            cross_kv.append(split_heads(k_proj(last_hidden_states)))
            cross_kv.append(split_heads(v_proj(last_hidden_states)))

        return torch.stack(cross_kv, dim=0)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        cross_key_values: torch.Tensor,  # 2 * n_layers, batch_size, num_heads, context_length, d_kv
    ) -> torch.Tensor:
        """Encode ``inputs_embeds`` and store the cross-attention K/V into ``cross_key_values``.

        Args:
            inputs_embeds: Encoder input embeddings.
            cross_key_values: Destination cache for the stacked cross-attention K/V.

        Returns:
            Whatever ``torch.ops.rbln_custom_ops.rbln_cache_update`` returns for the update.
        """
        # 1. get encoder last_hidden_states
        encoder_outputs = self.encoder(inputs_embeds=inputs_embeds, attention_mask=None, return_dict=False)
        last_hidden_states = encoder_outputs[0]

        # 2. pre-compute the cross-attention past_key_values used in the decoder phase.
        cross_kv = self._compute_cross_kv(last_hidden_states)

        # 3. write them into the cross-attention cache (device DRAM, per the original note).
        # NOTE(review): bidx=0 / axis=1 presumably mean "start at batch index 0 along the batch axis".
        bidx = torch.tensor(0, dtype=torch.int16)
        axis = torch.tensor(1, dtype=torch.int16)
        enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)

        return enc_output
|
93
|
+
|
94
|
+
|
95
|
+
class TimeSeriesTransformersDecoderWrapper(torch.nn.Module):
    """Single-step decoder wrapper for RBLN compilation.

    Rebuilds the original decoder with RBLN self/cross attention modules, regroups the
    flat self/cross KV-cache arguments into per-layer ``(key, value)`` pairs, runs the
    decoder and applies ``parameter_projection`` to the last position.
    """

    def __init__(self, model, num_parallel_samples):
        super().__init__()
        self.config = model.config
        self.num_layers = self.config.decoder_layers
        self.decoder = self.convert_to_rbln_tst_decoder(model, num_parallel_samples)
        self.parameter_projection = model.parameter_projection

    def convert_to_rbln_tst_decoder(self, model: nn.Module, num_parallel_samples: int):
        """Return a ``TimeSeriesTransformersDecoder`` whose layers use RBLN self/cross attention."""
        new_layers = []
        for layer in model.get_decoder().layers:
            self_attn = TimeSeriesTransformersSelfAttention(layer.self_attn, num_parallel_samples)
            cross_attn = TimeSeriesTransformersCrossAttention(layer.encoder_attn, num_parallel_samples)
            new_layers.append(TimeSeriesTransformersDecoderLayer(layer, self_attn, cross_attn))

        return TimeSeriesTransformersDecoder(model.get_decoder(), new_layers)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        decoder_attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        block_tables: torch.Tensor,
        cross_kv_cache: torch.Tensor,  # 2 * n_layers, batch_size, num_heads, context_length, d_kv
        *self_kv_cache: torch.Tensor,  # 2 * n_layers tensors: batch_size * num_parallel_samples, num_heads, prediction_length, d_kv
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder step.

        Args:
            inputs_embeds: Decoder input features.
            decoder_attention_mask: Decoder attention mask, forwarded as ``attention_mask``.
            cache_position: Current decoding position(s).
            block_tables: Block tables for the paged self-attention cache.
            cross_kv_cache: Stacked cross-attention K/V, indexed ``[k_0, v_0, k_1, v_1, ...]``.
            *self_kv_cache: Flat self-attention caches in the same ``k, v`` per-layer order.

        Returns:
            ``(params, last_hidden_states)`` where ``params`` is ``parameter_projection``
            applied to the last position of ``last_hidden_states``.
        """
        # Regroup the flat caches into one (key, value) pair per decoder layer.
        self_past_key_values = tuple(
            (self_kv_cache[2 * idx], self_kv_cache[2 * idx + 1]) for idx in range(self.num_layers)
        )
        cross_past_key_values = tuple(
            (cross_kv_cache[2 * idx], cross_kv_cache[2 * idx + 1]) for idx in range(self.num_layers)
        )

        # Decode
        last_hidden_states = self.decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=decoder_attention_mask,
            cache_position=cache_position,
            block_tables=block_tables,
            self_past_key_values=self_past_key_values,
            cross_past_key_values=cross_past_key_values,
        )

        # Only the last position is projected to distribution parameters.
        params = self.parameter_projection(last_hidden_states[:, -1:])

        return (params, last_hidden_states)
|
147
|
+
|
148
|
+
|
149
|
+
class TimeSeriesTransformersDecoder(nn.Module):
    """Decoder stack used by the RBLN time-series decoder wrapper.

    Reuses the original decoder's ``value_embedding``, ``embed_positions`` and
    ``layernorm_embedding`` modules and runs the supplied replacement layers in order,
    passing each layer its own self/cross KV-cache pair.

    Args:
        model: The original decoder module; must expose ``config``, ``value_embedding``,
            ``embed_positions`` and ``layernorm_embedding``.
        layers: Replacement decoder layers, stored in an ``nn.ModuleList``.
        **kwargs: Accepted for compatibility and ignored.
    """

    def __init__(self, model, layers, **kwargs):
        super().__init__()
        self._original_mod = model
        self.config = model.config
        self.layers = nn.ModuleList(layers)
        self.value_embedding = model.value_embedding
        self.embed_positions = model.embed_positions
        self.layernorm_embedding = model.layernorm_embedding

    def forward(
        self,
        inputs_embeds: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        self_past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        cross_past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: torch.Tensor = None,
    ) -> torch.Tensor:
        """Embed ``inputs_embeds`` and run every decoder layer.

        Args:
            inputs_embeds: Decoder input features, fed through ``value_embedding``.
            attention_mask: Decoder attention mask, expanded into a 4-D causal mask.
            self_past_key_values: One ``(key, value)`` self-attention cache pair per layer.
            cross_past_key_values: One ``(key, value)`` cross-attention pair per layer.
            cache_position: Current decoding position(s); also indexes the positional table.
            block_tables: Forwarded unchanged to every layer.

        Returns:
            The hidden states produced by the last decoder layer.
        """
        input_shape = inputs_embeds.size()[:-1]

        # prepare causal_attn_mask
        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)

        hidden_states = self.value_embedding(inputs_embeds)
        # Decoder positions are looked up with a context_length offset in the positional table.
        embed_pos = self.embed_positions.weight[cache_position + self.config.context_length]
        hidden_states = self.layernorm_embedding(hidden_states + embed_pos)

        # iterate decoder_layer
        for self_past_key_value, cross_past_key_value, decoder_layer in zip(
            self_past_key_values, cross_past_key_values, self.layers
        ):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                self_past_key_value=self_past_key_value,
                cross_past_key_value=cross_past_key_value,
                cache_position=cache_position,
                block_tables=block_tables,
            )

        return hidden_states
|
192
|
+
|
193
|
+
|
194
|
+
class TimeSeriesTransformersDecoderLayer(nn.Module):
    """Post-norm decoder layer built from an original decoder layer and RBLN attention modules.

    Three sub-blocks run in sequence: self-attention, cross-attention and a feed-forward
    network. Each adds its output to its input (residual) and then applies its layer norm.

    Args:
        decoder_layer: Original decoder layer supplying norms, activation and ``fc1``/``fc2``.
        self_attn: Replacement self-attention module.
        cross_attn: Replacement cross-attention module (stored as ``encoder_attn``).
    """

    def __init__(self, decoder_layer, self_attn, cross_attn):
        super().__init__()
        self._original_mod = decoder_layer
        self.self_attn = self_attn
        self.encoder_attn = cross_attn
        # Copy the remaining pieces in a fixed order (keeps submodule registration order stable).
        for attr_name in (
            "embed_dim",
            "self_attn_layer_norm",
            "encoder_attn_layer_norm",
            "final_layer_norm",
            "activation_fn",
            "fc1",
            "fc2",
        ):
            setattr(self, attr_name, getattr(decoder_layer, attr_name))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
        cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: torch.Tensor = None,
    ) -> torch.Tensor:
        """Apply the self-attention, cross-attention and feed-forward sub-blocks in turn."""
        # Self-attention sub-block: residual add, then layer norm.
        self_attn_out = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_past_key_value,
            attention_mask=attention_mask,
            cache_position=cache_position,
            block_tables=block_tables,
        )
        hidden_states = self.self_attn_layer_norm(hidden_states + self_attn_out)

        # Cross-attention sub-block (no encoder attention mask is passed).
        cross_attn_out = self.encoder_attn(
            hidden_states=hidden_states,
            past_key_value=cross_past_key_value,
        )
        hidden_states = self.encoder_attn_layer_norm(hidden_states + cross_attn_out)

        # Feed-forward sub-block.
        ffn_out = self.fc2(self.activation_fn(self.fc1(hidden_states)))
        return self.final_layer_norm(hidden_states + ffn_out)
|
247
|
+
|
248
|
+
|
249
|
+
class TimeSeriesTransformersAttention(nn.Module):
    """Base class that copies projections and head geometry from an existing attention module.

    Args:
        attn: Original attention module providing ``q_proj``, ``k_proj``, ``v_proj``,
            ``out_proj``, ``num_heads``, ``embed_dim``, ``head_dim`` and ``scaling``.
        num_parallel_samples: Stored as-is for use by subclasses.
    """

    def __init__(self, attn, num_parallel_samples):
        super().__init__()
        self._original_mod = attn
        # Same assignment order as the attribute list below (projections first, then scalars).
        for attr_name in ("q_proj", "k_proj", "v_proj", "out_proj", "num_heads", "embed_dim", "head_dim", "scaling"):
            setattr(self, attr_name, getattr(attn, attr_name))
        self.num_parallel_samples = num_parallel_samples

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        """Split the last dim into heads: ``(bsz, seq_len, ...) -> (bsz, num_heads, seq_len, head_dim)``."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)
|
265
|
+
|
266
|
+
|
267
|
+
class TimeSeriesTransformersSelfAttention(TimeSeriesTransformersAttention):
    """Decoder self-attention backed by the RBLN ``paged_add_softmax_attn_decode`` custom op.

    Queries are pre-scaled by ``self.scaling``, so the op is invoked with a scale of 1.0.
    Current-step keys/values and the paged KV cache are handed to the op together.
    """

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        """Reshape a projection to ``(1, bsz * num_heads, 1, seq_len, head_dim)`` for the paged op."""
        return tensor.view(1, seq_len, 1, bsz * self.num_heads, self.head_dim).transpose(1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend the current step to the paged self-attention cache.

        Args:
            hidden_states: Decoder hidden states ``(bsz, tgt_len, embed_dim)``.
            past_key_value: ``(key_cache, value_cache)`` for this layer (required despite the
                ``None`` default; it is indexed unconditionally).
            attention_mask: Attention mask; a dim is inserted at axis 2 before the op (required).
            cache_position: Current decoding position, expanded to ``(bsz, 1)``.
            block_tables: Block tables forwarded to the paged op.

        Returns:
            Attention output of shape ``(bsz, tgt_len, embed_dim)`` after ``out_proj``.
        """
        bsz, tgt_len, _ = hidden_states.size()
        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
        query_states = query_states * self.scaling

        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # NOTE(review): the cache's second-to-last dim is taken as the paged-block size.
        block_size = past_key_value[0].shape[-2]
        attn_output = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
            query_states,
            key_states,
            value_states,
            attention_mask.unsqueeze(2),
            past_key_value[0].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
            past_key_value[1].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
            cache_position.expand(bsz, 1),
            torch.tensor(1.0, dtype=torch.float32),  # scale (queries already scaled above)
            block_tables,
            block_size,
        )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output
|
306
|
+
|
307
|
+
|
308
|
+
class TimeSeriesTransformersCrossAttention(TimeSeriesTransformersSelfAttention):
    """Decoder cross-attention over pre-computed encoder key/values.

    The decoder batch is regrouped as ``(batch_size // num_parallel_samples,
    num_parallel_samples)`` so every sample of one input attends to the same cached
    encoder key/values through broadcasting. No encoder attention mask is applied.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Attend decoder states to the cached cross-attention key/values.

        Args:
            hidden_states: Decoder hidden states ``(batch_size, query_len, embed_dim)``;
                presumably ``batch_size`` is a multiple of ``num_parallel_samples``.
            past_key_value: ``(key, value)`` for this layer (required despite the ``None``
                default). Presumably each is ``(batch_size // num_parallel_samples,
                num_heads, context_length, head_dim)`` — TODO confirm against the caller.

        Returns:
            Attention output of shape ``(batch_size, query_len, embed_dim)`` after ``out_proj``.
        """
        batch_size, query_len, _ = hidden_states.size()
        # (B, q, E) -> (B // P, P, q, heads, head_dim) -> (B // P, P, heads, q, head_dim)
        query_states = (
            self.q_proj(hidden_states)
            .view(
                batch_size // self.num_parallel_samples,
                self.num_parallel_samples,
                query_len,
                self.num_heads,
                self.head_dim,
            )
            .transpose(2, 3)
        )
        query_states = query_states * self.scaling

        # Insert a samples axis so the encoder K/V broadcast across the parallel samples.
        key_states = past_key_value[0].unsqueeze(1)
        value_states = past_key_value[1].unsqueeze(1)

        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4))
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.view(batch_size, self.num_heads, query_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output
|
@@ -38,24 +38,15 @@ from .whisper_architecture import WhisperWrapper
|
|
38
38
|
logger = get_logger(__name__)
|
39
39
|
|
40
40
|
if TYPE_CHECKING:
|
41
|
-
from transformers import AutoFeatureExtractor, AutoProcessor, PretrainedConfig, PreTrainedModel
|
41
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, GenerationConfig, PretrainedConfig, PreTrainedModel
|
42
42
|
|
43
43
|
|
44
44
|
class RBLNRuntimeEncoder(RBLNPytorchRuntime):
|
45
45
|
mandatory_members = ["main_input_name"]
|
46
46
|
|
47
|
-
def forward(self,
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
n_pad_to_batch = self.batch_size - input_features.shape[0]
|
52
|
-
if n_pad_to_batch > 0:
|
53
|
-
input_features = torch.nn.functional.pad(input_features, (0, 0, 0, 0, 0, n_pad_to_batch))
|
54
|
-
|
55
|
-
_ = super().forward(input_features=input_features)
|
56
|
-
|
57
|
-
# dummy output for generation
|
58
|
-
return BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
|
47
|
+
def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
|
48
|
+
output = super().forward(*args, **kwargs)
|
49
|
+
return BaseModelOutput(last_hidden_state=output)
|
59
50
|
|
60
51
|
|
61
52
|
class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
@@ -65,10 +56,14 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
65
56
|
self,
|
66
57
|
runtime: rebel.Runtime,
|
67
58
|
batch_size: int,
|
59
|
+
dec_max_seq_len: int,
|
60
|
+
use_attention_mask: Optional[bool] = None,
|
68
61
|
**kwargs: Any,
|
69
62
|
) -> None:
|
70
63
|
super().__init__(runtime, **kwargs)
|
71
64
|
self.batch_size = batch_size
|
65
|
+
self.dec_max_seq_len = dec_max_seq_len
|
66
|
+
self.use_attention_mask = use_attention_mask
|
72
67
|
self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
|
73
68
|
|
74
69
|
def forward(
|
@@ -79,13 +74,23 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
79
74
|
):
|
80
75
|
inputs_bsz = decoder_input_ids.shape[0]
|
81
76
|
padded_bsz = self.batch_size - inputs_bsz
|
77
|
+
|
82
78
|
if padded_bsz > 0:
|
83
79
|
decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, 0, 0, padded_bsz))
|
84
80
|
|
81
|
+
if self.use_attention_mask:
|
82
|
+
for b_idx in range(self.batch_size):
|
83
|
+
decoding_step = cache_position[b_idx].item()
|
84
|
+
if not (0 <= decoding_step < self.dec_max_seq_len):
|
85
|
+
raise ValueError(
|
86
|
+
f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
|
87
|
+
)
|
88
|
+
decoder_attention_mask[b_idx, : decoding_step + 1] = 1
|
89
|
+
|
85
90
|
outputs = super().forward(
|
86
|
-
decoder_input_ids
|
87
|
-
decoder_attention_mask
|
88
|
-
cache_position
|
91
|
+
decoder_input_ids,
|
92
|
+
decoder_attention_mask if self.use_attention_mask else None,
|
93
|
+
cache_position,
|
89
94
|
block_tables=self.default_block_tables,
|
90
95
|
)
|
91
96
|
|
@@ -115,12 +120,15 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
115
120
|
self.batch_size = self.rbln_config.model_cfg["batch_size"]
|
116
121
|
self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
|
117
122
|
self.rbln_token_timestamps = self.rbln_config.model_cfg["token_timestamps"]
|
123
|
+
self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
|
118
124
|
|
119
|
-
self.encoder = RBLNRuntimeEncoder(
|
120
|
-
runtime=self.model[0], main_input_name="input_features", batch_size=self.batch_size
|
121
|
-
)
|
125
|
+
self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
|
122
126
|
self.decoder = RBLNRuntimeDecoder(
|
123
|
-
runtime=self.model[1],
|
127
|
+
runtime=self.model[1],
|
128
|
+
main_input_name="input_ids",
|
129
|
+
batch_size=self.batch_size,
|
130
|
+
dec_max_seq_len=self.dec_max_seq_len,
|
131
|
+
use_attention_mask=self.use_attention_mask,
|
124
132
|
)
|
125
133
|
|
126
134
|
# skip encoder & first decoder when language detected
|
@@ -132,6 +140,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
132
140
|
# input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
|
133
141
|
self.model = WhisperModel(self.config)
|
134
142
|
self.pad_token_id = self.config.pad_token_id
|
143
|
+
self.generation_config.forced_decoder_ids = None
|
135
144
|
|
136
145
|
def can_generate(self):
|
137
146
|
return True
|
@@ -163,7 +172,10 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
163
172
|
@classmethod
|
164
173
|
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
|
165
174
|
rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
|
166
|
-
|
175
|
+
use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
|
176
|
+
return WhisperWrapper(
|
177
|
+
model, use_attention_mask=use_attention_mask, rbln_token_timestamps=rbln_token_timestamps
|
178
|
+
)
|
167
179
|
|
168
180
|
@classmethod
|
169
181
|
@torch.inference_mode()
|
@@ -226,28 +238,33 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
226
238
|
if rbln_dec_max_seq_len is None:
|
227
239
|
rbln_dec_max_seq_len = model_config.max_length
|
228
240
|
|
229
|
-
#
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
241
|
+
# use_attention_mask conditions
|
242
|
+
rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
|
243
|
+
if rbln_use_attention_mask is None:
|
244
|
+
rbln_use_attention_mask = False
|
245
|
+
rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
|
246
|
+
if rbln_npu == "RBLN-CA02":
|
247
|
+
rbln_use_attention_mask = True
|
248
|
+
|
249
|
+
enc_input_info = [
|
250
|
+
("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
|
251
|
+
("block_tables", [1], "int16"),
|
252
|
+
(
|
253
|
+
"cross_key_value_states",
|
254
|
+
[
|
255
|
+
model_config.decoder_layers * 2,
|
256
|
+
rbln_batch_size,
|
257
|
+
model_config.decoder_attention_heads,
|
258
|
+
enc_max_seq_len,
|
259
|
+
model_config.d_model // model_config.decoder_attention_heads,
|
260
|
+
],
|
261
|
+
"float32",
|
262
|
+
),
|
263
|
+
]
|
246
264
|
|
247
265
|
dec_input_info = [
|
248
266
|
("decoder_input_ids", [rbln_batch_size, 1], "int64"),
|
249
|
-
("
|
250
|
-
("cache_position", [], "int32"),
|
267
|
+
("cache_position", [rbln_batch_size, 1], "int32"),
|
251
268
|
("block_tables", [rbln_batch_size, 1], "int16"),
|
252
269
|
]
|
253
270
|
dec_input_info.extend(
|
@@ -281,6 +298,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
281
298
|
]
|
282
299
|
)
|
283
300
|
|
301
|
+
if rbln_use_attention_mask:
|
302
|
+
dec_input_info.insert(1, ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
|
303
|
+
|
284
304
|
enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
|
285
305
|
dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
|
286
306
|
|
@@ -295,6 +315,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
295
315
|
"batch_size": rbln_batch_size,
|
296
316
|
"dec_max_seq_len": rbln_dec_max_seq_len,
|
297
317
|
"token_timestamps": rbln_token_timestamps,
|
318
|
+
"use_attention_mask": rbln_use_attention_mask,
|
298
319
|
}
|
299
320
|
)
|
300
321
|
|
@@ -339,11 +360,25 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
339
360
|
|
340
361
|
# https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L512
|
341
362
|
def _prepare_encoder_decoder_kwargs_for_generation(
|
342
|
-
self,
|
363
|
+
self,
|
364
|
+
inputs_tensor: torch.Tensor,
|
365
|
+
model_kwargs,
|
366
|
+
model_input_name: Optional[str] = None,
|
367
|
+
generation_config: Optional["GenerationConfig"] = None,
|
368
|
+
**kwargs,
|
343
369
|
) -> Dict[str, Any]:
|
370
|
+
batch_size = inputs_tensor.shape[0]
|
371
|
+
n_pad_to_batch = self.batch_size - batch_size
|
372
|
+
if n_pad_to_batch > 0:
|
373
|
+
inputs_tensor = torch.nn.functional.pad(inputs_tensor, (0, 0, 0, 0, 0, n_pad_to_batch))
|
374
|
+
|
344
375
|
if not self.is_language_detected:
|
345
|
-
|
346
|
-
|
376
|
+
for b in range(inputs_tensor.shape[0]):
|
377
|
+
block_tables = torch.tensor([b], dtype=torch.int16)
|
378
|
+
model_kwargs["encoder_outputs"] = self.encoder(
|
379
|
+
input_features=inputs_tensor[b].unsqueeze(0), block_tables=block_tables
|
380
|
+
)
|
381
|
+
self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
|
347
382
|
else:
|
348
383
|
model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
|
349
384
|
|
@@ -371,7 +406,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
371
406
|
decoder_output = self.decoder(
|
372
407
|
decoder_input_ids=input_ids[:, step : step + 1].contiguous(),
|
373
408
|
decoder_attention_mask=self.decoder_attention_mask,
|
374
|
-
cache_position=
|
409
|
+
cache_position=torch.full((self.batch_size, 1), step, dtype=torch.int32),
|
375
410
|
)
|
376
411
|
cross_attentions.append(decoder_output.cross_attentions)
|
377
412
|
lm_logits = decoder_output.logits
|
@@ -386,15 +421,19 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
386
421
|
# detect language pass
|
387
422
|
# https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/whisper/generation_whisper.py#L1442
|
388
423
|
else:
|
424
|
+
# for language auto detection (generate with language=None)
|
389
425
|
if encoder_outputs is None:
|
390
|
-
|
391
|
-
|
426
|
+
for b in range(input_features.shape[0]):
|
427
|
+
block_tables = torch.tensor([b], dtype=torch.int16)
|
428
|
+
self.encoder(input_features=input_features[b].unsqueeze(0), block_tables=block_tables)
|
429
|
+
|
430
|
+
self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
|
392
431
|
self.is_language_detected = True
|
393
432
|
self.decoder_attention_mask[:, 0] = 1
|
394
433
|
decoder_output = self.decoder(
|
395
434
|
decoder_input_ids=decoder_input_ids.contiguous(),
|
396
435
|
decoder_attention_mask=self.decoder_attention_mask,
|
397
|
-
cache_position=torch.zeros([], dtype=torch.int32),
|
436
|
+
cache_position=torch.zeros([self.rbln_config.model_cfg["batch_size"], 1], dtype=torch.int32),
|
398
437
|
)
|
399
438
|
lm_logits = decoder_output.logits
|
400
439
|
self.language_cross = decoder_output.cross_attentions
|