optimum-rbln 0.7.4a0__py3-none-any.whl → 0.7.4a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,341 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    Seq2SeqLMOutput,
+)
+from transformers.utils import logging
+
+from ....ops import register_rbln_custom_cache_update, register_rbln_custom_paged_add_softmax_attention
+
+
+logger = logging.get_logger(__name__)
+
+
+class TimeSeriesTransformersWrapper:
+    def __init__(self, model, num_parallel_samples):
+        register_rbln_custom_cache_update()
+        register_rbln_custom_paged_add_softmax_attention()
+        self.encoder = TimeSeriesTransformersEncoderWrapper(model)
+        self.decoder = TimeSeriesTransformersDecoderWrapper(model, num_parallel_samples)
+
+
+class TimeSeriesTransformersEncoderWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.config = model.config
+        self.encoder = model.get_encoder()
+        self.num_heads = self.config.decoder_attention_heads
+        self.d_kv = self.config.d_model // self.num_heads
+        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)
+
+    def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
+        return (
+            nn.ModuleList(layer.encoder_attn.k_proj for layer in decoder_layers),
+            nn.ModuleList(layer.encoder_attn.v_proj for layer in decoder_layers),
+        )
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        cross_key_values: torch.Tensor,  # n_layers, batch_size, num_heads, context_length, d_kv
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+        # 1. get encoder last_hidden_states
+        encoder_outputs = self.encoder(inputs_embeds=inputs_embeds, attention_mask=None, return_dict=False)
+        last_hidden_states = encoder_outputs[0]
+
+        # 2. pre-compute the cross-attention past_key_values used in the decode phase.
+        cross_kv = []
+        batch_size = inputs_embeds.shape[0]
+        for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
+            past_k = k_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
+            past_v = v_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
+
+            cross_kv.append(past_k)
+            cross_kv.append(past_v)
+
+        cross_kv = torch.stack(cross_kv, dim=0)
+
+        # 3. update cross_attention's past_key_value to the device-dram for optimization.
+        bidx = torch.tensor(0, dtype=torch.int16)
+        axis = torch.tensor(1, dtype=torch.int16)
+        enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
+
+        return enc_output
+
+
+class TimeSeriesTransformersDecoderWrapper(torch.nn.Module):
+    def __init__(self, model, num_parallel_samples):
+        super().__init__()
+        self.config = model.config
+        self.num_layers = self.config.decoder_layers
+        self.decoder = self.convert_to_rbln_tst_decoder(model, num_parallel_samples)
+        self.parameter_projection = model.parameter_projection
+
+    def convert_to_rbln_tst_decoder(self, model: nn.Module, num_parallel_samples: int):
+        new_layers = []
+        for layer in model.get_decoder().layers:
+            self_attn = TimeSeriesTransformersSelfAttention(layer.self_attn, num_parallel_samples)
+            cross_attn = TimeSeriesTransformersCrossAttention(layer.encoder_attn, num_parallel_samples)
+            new_layers.append(TimeSeriesTransformersDecoderLayer(layer, self_attn, cross_attn))
+
+        decoder_model = TimeSeriesTransformersDecoder(model.get_decoder(), new_layers)
+
+        return decoder_model
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        decoder_attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        block_tables: torch.Tensor,
+        cross_kv_cache: torch.Tensor,  # batch_size, num_heads, context_length, d_kv
+        *self_kv_cache: torch.Tensor,  # batch_size * num_parallel_samples, num_heads, prediction_length, d_kv
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        # prepare past_key_values
+        self_past_key_values = ()
+        cross_past_key_values = ()
+        for i in range(0, self.num_layers * 2, 2):
+            self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+            cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+
+        # Decode
+        last_hidden_states = self.decoder(
+            inputs_embeds=inputs_embeds,
+            attention_mask=decoder_attention_mask,
+            cache_position=cache_position,
+            block_tables=block_tables,
+            self_past_key_values=self_past_key_values,
+            cross_past_key_values=cross_past_key_values,
+        )
+
+        params = self.parameter_projection(last_hidden_states[:, -1:])
+
+        outputs = ()
+        outputs += (params,)
+        outputs += (last_hidden_states,)
+
+        return outputs
+
+
+class TimeSeriesTransformersDecoder(nn.Module):
+    def __init__(self, model, layers, **kwargs):
+        super().__init__()
+        self._original_mod = model
+        self.config = model.config
+        self.layers = nn.ModuleList(layers)
+        self.value_embedding = model.value_embedding
+        self.embed_positions = model.embed_positions
+        self.layernorm_embedding = model.layernorm_embedding
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        self_past_key_values: Optional[torch.Tensor] = None,
+        cross_past_key_values: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        block_tables: torch.Tensor = None,
+    ):
+        input_shape = inputs_embeds.size()[:-1]
+
+        # prepare causal_attn_mask
+        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
+
+        hidden_states = self.value_embedding(inputs_embeds)
+        embed_pos = self.embed_positions.weight[cache_position + self.config.context_length]
+        hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
+
+        # iterate decoder_layer
+        for self_past_key_value, cross_past_key_value, decoder_layer in zip(
+            self_past_key_values, cross_past_key_values, self.layers
+        ):
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                # encoder_attention_mask=encoder_attention_mask,
+                self_past_key_value=self_past_key_value,
+                cross_past_key_value=cross_past_key_value,
+                cache_position=cache_position,
+                block_tables=block_tables,
+            )
+
+        return hidden_states
+
+
+class TimeSeriesTransformersDecoderLayer(nn.Module):
+    def __init__(self, decoder_layer, self_attn, cross_attn):
+        super().__init__()
+        self._original_mod = decoder_layer
+        self.self_attn = self_attn
+        self.encoder_attn = cross_attn
+        self.embed_dim = decoder_layer.embed_dim
+        self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
+        self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
+        self.final_layer_norm = decoder_layer.final_layer_norm
+        self.activation_fn = decoder_layer.activation_fn
+        self.fc1 = decoder_layer.fc1
+        self.fc2 = decoder_layer.fc2
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        block_tables: torch.Tensor = None,
+    ) -> torch.Tensor:
+        # Self Attention Block
+        residual = hidden_states
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_past_key_value,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            block_tables=block_tables,
+        )
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Cross-Attention Block
+        residual = hidden_states
+        hidden_states = self.encoder_attn(
+            hidden_states=hidden_states,
+            past_key_value=cross_past_key_value,
+            # attention_mask=encoder_attention_mask,
+        )
+        hidden_states = residual + hidden_states
+        hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+        # Fully Connected Block
+        residual = hidden_states
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        return hidden_states
+
+
+class TimeSeriesTransformersAttention(nn.Module):
+    def __init__(self, attn, num_parallel_samples):
+        super().__init__()
+        self._original_mod = attn
+        self.q_proj = attn.q_proj
+        self.k_proj = attn.k_proj
+        self.v_proj = attn.v_proj
+        self.out_proj = attn.out_proj
+        self.num_heads = attn.num_heads
+        self.embed_dim = attn.embed_dim
+        self.head_dim = attn.head_dim
+        self.scaling = attn.scaling
+        self.num_parallel_samples = num_parallel_samples
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+
+class TimeSeriesTransformersSelfAttention(TimeSeriesTransformersAttention):
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(1, seq_len, 1, bsz * self.num_heads, self.head_dim).transpose(1, 3)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, tgt_len, _ = hidden_states.size()
+        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
+        query_states = query_states * self.scaling
+
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        block_size = past_key_value[0].shape[-2]
+        attn_output = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask.unsqueeze(2),
+            past_key_value[0].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
+            past_key_value[1].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
+            cache_position.expand(bsz, 1),
+            torch.tensor(1.0, dtype=torch.float32),  # scale
+            block_tables,
+            block_size,
+        )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
+
+
+class TimeSeriesTransformersCrossAttention(TimeSeriesTransformersSelfAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        batch_size, query_len, _ = hidden_states.size()
+        query_states = (
+            self.q_proj(hidden_states)
+            .view(
+                batch_size // self.num_parallel_samples,
+                self.num_parallel_samples,
+                query_len,
+                self.num_heads,
+                self.head_dim,
+            )
+            .transpose(2, 3)
+        )
+        query_states = query_states * self.scaling
+
+        key_states = past_key_value[0].unsqueeze(1)
+        value_states = past_key_value[1].unsqueeze(1)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4))
+        attn_weights = attn_weights
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.view(batch_size, self.num_heads, query_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
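
The encoder wrapper in this new file pre-computes every decoder layer's cross-attention key/value states once per context window. Below is a minimal sketch of that shape flow, using toy dimensions and freshly initialized nn.Linear layers as stand-ins for each layer's encoder_attn.k_proj / encoder_attn.v_proj; the real code writes the resulting stack into a preallocated cache via torch.ops.rbln_custom_ops.rbln_cache_update, which is replaced here by a plain copy_ for illustration only.

import torch
from torch import nn

# Toy dimensions; real values come from the TimeSeriesTransformer config.
num_layers, batch, ctx_len, d_model, num_heads = 2, 4, 24, 32, 4
d_kv = d_model // num_heads

# Stand-ins for the per-layer cross-attention projections gathered by
# _extract_cross_kv_projects (assumption: plain Linear layers, not the real weights).
k_projects = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_layers))
v_projects = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_layers))

last_hidden_states = torch.randn(batch, ctx_len, d_model)  # encoder output

cross_kv = []
for k_proj, v_proj in zip(k_projects, v_projects):
    past_k = k_proj(last_hidden_states).view(batch, -1, num_heads, d_kv).transpose(1, 2)
    past_v = v_proj(last_hidden_states).view(batch, -1, num_heads, d_kv).transpose(1, 2)
    cross_kv.append(past_k)
    cross_kv.append(past_v)

# One K and one V entry per decoder layer, matching the cross_key_values buffer layout.
cross_kv = torch.stack(cross_kv, dim=0)
print(cross_kv.shape)  # torch.Size([4, 4, 4, 24, 8]) -> (2 * num_layers, batch, num_heads, ctx_len, d_kv)

# The wrapper hands this stack to the rbln_cache_update custom op; a plain copy_
# into a preallocated buffer stands in for that device-side cache write here.
cross_key_values = torch.zeros_like(cross_kv)
cross_key_values.copy_(cross_kv)

Stacking K and V per layer produces the (2 * decoder_layers, batch, num_heads, context_length, d_kv) layout that the decoder wrapper later slices back into per-layer (key, value) pairs in its `for i in range(0, self.num_layers * 2, 2)` loop.
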
@@ -38,24 +38,15 @@ from .whisper_architecture import WhisperWrapper
 logger = get_logger(__name__)

 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, PretrainedConfig, PreTrainedModel
+    from transformers import AutoFeatureExtractor, AutoProcessor, GenerationConfig, PretrainedConfig, PreTrainedModel


 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]

-    def forward(self, input_features: torch.Tensor = None):
-        # backward compatibility transformers==4.40.2
-        # https://github.com/huggingface/transformers/blob/4fdf58afb72b0754da30037fc800b6044e7d9c99/src/transformers/pipelines/automatic_speech_recognition.py#L494
-
-        n_pad_to_batch = self.batch_size - input_features.shape[0]
-        if n_pad_to_batch > 0:
-            input_features = torch.nn.functional.pad(input_features, (0, 0, 0, 0, 0, n_pad_to_batch))
-
-        _ = super().forward(input_features=input_features)
-
-        # dummy output for generation
-        return BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
+    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+        output = super().forward(*args, **kwargs)
+        return BaseModelOutput(last_hidden_state=output)


 class RBLNRuntimeDecoder(RBLNPytorchRuntime):
@@ -65,10 +56,14 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         self,
         runtime: rebel.Runtime,
         batch_size: int,
+        dec_max_seq_len: int,
+        use_attention_mask: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.batch_size = batch_size
+        self.dec_max_seq_len = dec_max_seq_len
+        self.use_attention_mask = use_attention_mask
         self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)

     def forward(
@@ -79,13 +74,23 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
     ):
         inputs_bsz = decoder_input_ids.shape[0]
         padded_bsz = self.batch_size - inputs_bsz
+
         if padded_bsz > 0:
             decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, 0, 0, padded_bsz))

+        if self.use_attention_mask:
+            for b_idx in range(self.batch_size):
+                decoding_step = cache_position[b_idx].item()
+                if not (0 <= decoding_step < self.dec_max_seq_len):
+                    raise ValueError(
+                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    )
+                decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
         outputs = super().forward(
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            cache_position=cache_position,
+            decoder_input_ids,
+            decoder_attention_mask if self.use_attention_mask else None,
+            cache_position,
             block_tables=self.default_block_tables,
         )

@@ -115,12 +120,15 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
         self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
         self.rbln_token_timestamps = self.rbln_config.model_cfg["token_timestamps"]
+        self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)

-        self.encoder = RBLNRuntimeEncoder(
-            runtime=self.model[0], main_input_name="input_features", batch_size=self.batch_size
-        )
+        self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
         self.decoder = RBLNRuntimeDecoder(
-            runtime=self.model[1], main_input_name="input_ids", batch_size=self.batch_size
+            runtime=self.model[1],
+            main_input_name="input_ids",
+            batch_size=self.batch_size,
+            dec_max_seq_len=self.dec_max_seq_len,
+            use_attention_mask=self.use_attention_mask,
         )

         # skip encoder & first decoder when language detected
@@ -132,6 +140,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         # input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
         self.model = WhisperModel(self.config)
         self.pad_token_id = self.config.pad_token_id
+        self.generation_config.forced_decoder_ids = None

     def can_generate(self):
         return True
@@ -163,7 +172,10 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
-        return WhisperWrapper(model, rbln_token_timestamps)
+        use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
+        return WhisperWrapper(
+            model, use_attention_mask=use_attention_mask, rbln_token_timestamps=rbln_token_timestamps
+        )

     @classmethod
     @torch.inference_mode()
@@ -226,28 +238,33 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         if rbln_dec_max_seq_len is None:
             rbln_dec_max_seq_len = model_config.max_length

-        # model input info
-        enc_input_info = [("input_features", [rbln_batch_size, num_mel_bins, expected_seq_len], "float32")]
-        enc_input_info.extend(
-            [
-                (
-                    "cross_key_value_states",
-                    [
-                        model_config.decoder_layers * 2,
-                        rbln_batch_size,
-                        model_config.decoder_attention_heads,
-                        enc_max_seq_len,
-                        model_config.d_model // model_config.decoder_attention_heads,
-                    ],
-                    "float32",
-                )
-            ]
-        )
+        # use_attention_mask conditions
+        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+        if rbln_use_attention_mask is None:
+            rbln_use_attention_mask = False
+            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+            if rbln_npu == "RBLN-CA02":
+                rbln_use_attention_mask = True
+
+        enc_input_info = [
+            ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
+            ("block_tables", [1], "int16"),
+            (
+                "cross_key_value_states",
+                [
+                    model_config.decoder_layers * 2,
+                    rbln_batch_size,
+                    model_config.decoder_attention_heads,
+                    enc_max_seq_len,
+                    model_config.d_model // model_config.decoder_attention_heads,
+                ],
+                "float32",
+            ),
+        ]

         dec_input_info = [
             ("decoder_input_ids", [rbln_batch_size, 1], "int64"),
-            ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
-            ("cache_position", [], "int32"),
+            ("cache_position", [rbln_batch_size, 1], "int32"),
             ("block_tables", [rbln_batch_size, 1], "int16"),
         ]
         dec_input_info.extend(
@@ -281,6 +298,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
             ]
         )

+        if rbln_use_attention_mask:
+            dec_input_info.insert(1, ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+
         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

@@ -295,6 +315,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
                 "batch_size": rbln_batch_size,
                 "dec_max_seq_len": rbln_dec_max_seq_len,
                 "token_timestamps": rbln_token_timestamps,
+                "use_attention_mask": rbln_use_attention_mask,
             }
         )

@@ -339,11 +360,25 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)

     # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L512
     def _prepare_encoder_decoder_kwargs_for_generation(
-        self, inputs_tensor: torch.Tensor, model_kwargs, *args, **kwargs
+        self,
+        inputs_tensor: torch.Tensor,
+        model_kwargs,
+        model_input_name: Optional[str] = None,
+        generation_config: Optional["GenerationConfig"] = None,
+        **kwargs,
     ) -> Dict[str, Any]:
+        batch_size = inputs_tensor.shape[0]
+        n_pad_to_batch = self.batch_size - batch_size
+        if n_pad_to_batch > 0:
+            inputs_tensor = torch.nn.functional.pad(inputs_tensor, (0, 0, 0, 0, 0, n_pad_to_batch))
+
         if not self.is_language_detected:
-            model_kwargs["encoder_outputs"] = self.encoder(input_features=inputs_tensor)
-            self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+            for b in range(inputs_tensor.shape[0]):
+                block_tables = torch.tensor([b], dtype=torch.int16)
+                model_kwargs["encoder_outputs"] = self.encoder(
+                    input_features=inputs_tensor[b].unsqueeze(0), block_tables=block_tables
+                )
+            self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
         else:
             model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))

@@ -371,7 +406,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
             decoder_output = self.decoder(
                 decoder_input_ids=input_ids[:, step : step + 1].contiguous(),
                 decoder_attention_mask=self.decoder_attention_mask,
-                cache_position=step.to(torch.int32),
+                cache_position=torch.full((self.batch_size, 1), step, dtype=torch.int32),
             )
             cross_attentions.append(decoder_output.cross_attentions)
             lm_logits = decoder_output.logits
@@ -386,15 +421,19 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         # detect language pass
         # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/whisper/generation_whisper.py#L1442
         else:
+            # for language auto detection (generate with language=None)
             if encoder_outputs is None:
-                self.encoder(input_features=input_features.contiguous())
-                self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+                for b in range(input_features.shape[0]):
+                    block_tables = torch.tensor([b], dtype=torch.int16)
+                    self.encoder(input_features=input_features[b].unsqueeze(0), block_tables=block_tables)
+
+                self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
             self.is_language_detected = True
             self.decoder_attention_mask[:, 0] = 1
             decoder_output = self.decoder(
                 decoder_input_ids=decoder_input_ids.contiguous(),
                 decoder_attention_mask=self.decoder_attention_mask,
-                cache_position=torch.zeros([], dtype=torch.int32),
+                cache_position=torch.zeros([self.rbln_config.model_cfg["batch_size"], 1], dtype=torch.int32),
             )
             lm_logits = decoder_output.logits
             self.language_cross = decoder_output.cross_attentions
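
For reference, a minimal sketch of the decoder-side conventions introduced in this diff, with toy sizes and no RBLN runtime involved: cache_position is now shaped [batch_size, 1] rather than a scalar, and when use_attention_mask is enabled (it is forced on when the target NPU is RBLN-CA02) the runtime marks every position up to and including the current decoding step as visible before calling the compiled decoder.

import torch

# Toy sizes; real values come from rbln_config ("batch_size", "dec_max_seq_len").
batch_size, dec_max_seq_len = 2, 8

# cache_position is now a per-batch tensor of shape [batch_size, 1] instead of a scalar.
step = 3
cache_position = torch.full((batch_size, 1), step, dtype=torch.int32)

# With use_attention_mask enabled, the mask is float32 and filled up to the current step.
decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len, dtype=torch.float32)
for b_idx in range(batch_size):
    decoding_step = cache_position[b_idx].item()
    if not (0 <= decoding_step < dec_max_seq_len):
        raise ValueError(f"Decoding step {decoding_step} out of bounds for dec_max_seq_len={dec_max_seq_len}.")
    decoder_attention_mask[b_idx, : decoding_step + 1] = 1

print(decoder_attention_mask)
# tensor([[1., 1., 1., 1., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 0., 0., 0., 0.]])

When use_attention_mask is left off, the runtime passes None in place of the mask, matching the conditional argument in RBLNRuntimeDecoder.forward shown above.
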