optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,341 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    Seq2SeqLMOutput,
+)
+from transformers.utils import logging
+
+from ....ops import register_rbln_custom_cache_update, register_rbln_custom_paged_add_softmax_attention
+
+
+logger = logging.get_logger(__name__)
+
+
+class TimeSeriesTransformersWrapper:
+    def __init__(self, model, num_parallel_samples):
+        register_rbln_custom_cache_update()
+        register_rbln_custom_paged_add_softmax_attention()
+        self.encoder = TimeSeriesTransformersEncoderWrapper(model)
+        self.decoder = TimeSeriesTransformersDecoderWrapper(model, num_parallel_samples)
+
+
+class TimeSeriesTransformersEncoderWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.config = model.config
+        self.encoder = model.get_encoder()
+        self.num_heads = self.config.decoder_attention_heads
+        self.d_kv = self.config.d_model // self.num_heads
+        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)
+
+    def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
+        return (
+            nn.ModuleList(layer.encoder_attn.k_proj for layer in decoder_layers),
+            nn.ModuleList(layer.encoder_attn.v_proj for layer in decoder_layers),
+        )
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        cross_key_values: torch.Tensor,  # n_layers, batch_size, num_heads, context_length, d_kv
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+        # 1. get encoder last_hidden_states
+        encoder_outputs = self.encoder(inputs_embeds=inputs_embeds, attention_mask=None, return_dict=False)
+        last_hidden_states = encoder_outputs[0]
+
+        # 2. pre-compute the cross-attention past_key_value, which is used in the decoder phase.
+        cross_kv = []
+        batch_size = inputs_embeds.shape[0]
+        for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
+            past_k = k_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
+            past_v = v_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
+
+            cross_kv.append(past_k)
+            cross_kv.append(past_v)
+
+        cross_kv = torch.stack(cross_kv, dim=0)
+
+        # 3. write the cross-attention past_key_value to device DRAM so the decoder can reuse it.
+        bidx = torch.tensor(0, dtype=torch.int16)
+        axis = torch.tensor(1, dtype=torch.int16)
+        enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
+
+        return enc_output
+
+
+class TimeSeriesTransformersDecoderWrapper(torch.nn.Module):
+    def __init__(self, model, num_parallel_samples):
+        super().__init__()
+        self.config = model.config
+        self.num_layers = self.config.decoder_layers
+        self.decoder = self.convert_to_rbln_tst_decoder(model, num_parallel_samples)
+        self.parameter_projection = model.parameter_projection
+
+    def convert_to_rbln_tst_decoder(self, model: nn.Module, num_parallel_samples: int):
+        new_layers = []
+        for layer in model.get_decoder().layers:
+            self_attn = TimeSeriesTransformersSelfAttention(layer.self_attn, num_parallel_samples)
+            cross_attn = TimeSeriesTransformersCrossAttention(layer.encoder_attn, num_parallel_samples)
+            new_layers.append(TimeSeriesTransformersDecoderLayer(layer, self_attn, cross_attn))
+
+        decoder_model = TimeSeriesTransformersDecoder(model.get_decoder(), new_layers)
+
+        return decoder_model
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        decoder_attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        block_tables: torch.Tensor,
+        cross_kv_cache: torch.Tensor,  # batch_size, num_heads, context_length, d_kv
+        *self_kv_cache: torch.Tensor,  # batch_size * num_parallel_samples, num_heads, prediction_length, d_kv
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        # prepare past_key_values
+        self_past_key_values = ()
+        cross_past_key_values = ()
+        for i in range(0, self.num_layers * 2, 2):
+            self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+            cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+
+        # Decode
+        last_hidden_states = self.decoder(
+            inputs_embeds=inputs_embeds,
+            attention_mask=decoder_attention_mask,
+            cache_position=cache_position,
+            block_tables=block_tables,
+            self_past_key_values=self_past_key_values,
+            cross_past_key_values=cross_past_key_values,
+        )
+
+        params = self.parameter_projection(last_hidden_states[:, -1:])
+
+        outputs = ()
+        outputs += (params,)
+        outputs += (last_hidden_states,)
+
+        return outputs
+
+
148
+
149
+ class TimeSeriesTransformersDecoder(nn.Module):
150
+ def __init__(self, model, layers, **kwargs):
151
+ super().__init__()
152
+ self._original_mod = model
153
+ self.config = model.config
154
+ self.layers = nn.ModuleList(layers)
155
+ self.value_embedding = model.value_embedding
156
+ self.embed_positions = model.embed_positions
157
+ self.layernorm_embedding = model.layernorm_embedding
158
+
159
+ def forward(
160
+ self,
161
+ inputs_embeds: torch.Tensor = None,
162
+ attention_mask: Optional[torch.Tensor] = None,
163
+ self_past_key_values: Optional[torch.Tensor] = None,
164
+ cross_past_key_values: Optional[torch.Tensor] = None,
165
+ cache_position: Optional[torch.Tensor] = None,
166
+ block_tables: torch.Tensor = None,
167
+ ):
168
+ input_shape = inputs_embeds.size()[:-1]
169
+
170
+ # prepare casual_attn_mask
171
+ attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
172
+
173
+ hidden_states = self.value_embedding(inputs_embeds)
174
+ embed_pos = self.embed_positions.weight[cache_position + self.config.context_length]
175
+ hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
176
+
177
+ # iterate decoder_layer
178
+ for self_past_key_value, cross_past_key_value, decoder_layer in zip(
179
+ self_past_key_values, cross_past_key_values, self.layers
180
+ ):
181
+ hidden_states = decoder_layer(
182
+ hidden_states,
183
+ attention_mask=attention_mask,
184
+ # encoder_attention_mask=encoder_attention_mask,
185
+ self_past_key_value=self_past_key_value,
186
+ cross_past_key_value=cross_past_key_value,
187
+ cache_position=cache_position,
188
+ block_tables=block_tables,
189
+ )
190
+
191
+ return hidden_states
192
+
193
+
194
+ class TimeSeriesTransformersDecoderLayer(nn.Module):
195
+ def __init__(self, decoder_layer, self_attn, cross_attn):
196
+ super().__init__()
197
+ self._original_mod = decoder_layer
198
+ self.self_attn = self_attn
199
+ self.encoder_attn = cross_attn
200
+ self.embed_dim = decoder_layer.embed_dim
201
+ self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
202
+ self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
203
+ self.final_layer_norm = decoder_layer.final_layer_norm
204
+ self.activation_fn = decoder_layer.activation_fn
205
+ self.fc1 = decoder_layer.fc1
206
+ self.fc2 = decoder_layer.fc2
207
+
208
+ def forward(
209
+ self,
210
+ hidden_states: torch.Tensor,
211
+ attention_mask: Optional[torch.Tensor] = None,
212
+ self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
213
+ cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
214
+ cache_position: Optional[torch.Tensor] = None,
215
+ block_tables: torch.Tensor = None,
216
+ ) -> torch.Tensor:
217
+ # Self Attention Block
218
+ residual = hidden_states
219
+ hidden_states = self.self_attn(
220
+ hidden_states=hidden_states,
221
+ past_key_value=self_past_key_value,
222
+ attention_mask=attention_mask,
223
+ cache_position=cache_position,
224
+ block_tables=block_tables,
225
+ )
226
+ hidden_states = residual + hidden_states
227
+ hidden_states = self.self_attn_layer_norm(hidden_states)
228
+
229
+ # Cross-Attention Block
230
+ residual = hidden_states
231
+ hidden_states = self.encoder_attn(
232
+ hidden_states=hidden_states,
233
+ past_key_value=cross_past_key_value,
234
+ # attention_mask=encoder_attention_mask,
235
+ )
236
+ hidden_states = residual + hidden_states
237
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
238
+
239
+ # Fully Connected Block
240
+ residual = hidden_states
241
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
242
+ hidden_states = self.fc2(hidden_states)
243
+ hidden_states = residual + hidden_states
244
+ hidden_states = self.final_layer_norm(hidden_states)
245
+
246
+ return hidden_states
247
+
248
+
249
+ class TimeSeriesTransformersAttention(nn.Module):
250
+ def __init__(self, attn, num_parallel_samples):
251
+ super().__init__()
252
+ self._original_mod = attn
253
+ self.q_proj = attn.q_proj
254
+ self.k_proj = attn.k_proj
255
+ self.v_proj = attn.v_proj
256
+ self.out_proj = attn.out_proj
257
+ self.num_heads = attn.num_heads
258
+ self.embed_dim = attn.embed_dim
259
+ self.head_dim = attn.head_dim
260
+ self.scaling = attn.scaling
261
+ self.num_parallel_samples = num_parallel_samples
262
+
263
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
264
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
265
+
266
+
+class TimeSeriesTransformersSelfAttention(TimeSeriesTransformersAttention):
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(1, seq_len, 1, bsz * self.num_heads, self.head_dim).transpose(1, 3)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, tgt_len, _ = hidden_states.size()
+        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
+        query_states = query_states * self.scaling
+
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        block_size = past_key_value[0].shape[-2]
+        attn_output = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask.unsqueeze(2),
+            past_key_value[0].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
+            past_key_value[1].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
+            cache_position.expand(bsz, 1),
+            torch.tensor(1.0, dtype=torch.float32),  # scale
+            block_tables,
+            block_size,
+        )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
+
+
+class TimeSeriesTransformersCrossAttention(TimeSeriesTransformersSelfAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        batch_size, query_len, _ = hidden_states.size()
+        query_states = (
+            self.q_proj(hidden_states)
+            .view(
+                batch_size // self.num_parallel_samples,
+                self.num_parallel_samples,
+                query_len,
+                self.num_heads,
+                self.head_dim,
+            )
+            .transpose(2, 3)
+        )
+        query_states = query_states * self.scaling
+
+        key_states = past_key_value[0].unsqueeze(1)
+        value_states = past_key_value[1].unsqueeze(1)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4))
+        attn_weights = attn_weights
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.view(batch_size, self.num_heads, query_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
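For orientation, the encoder wrapper above computes the cross-attention key/value states once per forecast and writes them into a persistent cache of shape `(decoder_layers * 2, batch, num_heads, context_length, d_kv)`; the decoder wrapper then pairs that flat `[k0, v0, k1, v1, ...]` layout back into per-layer `(key, value)` tuples. Below is a minimal CPU-only sketch of that layout, with made-up sizes and a plain tensor plus `copy_` standing in for the `rbln_cache_update` custom op; it illustrates the data flow, not the compiled graph.

```python
import torch

# Illustrative sizes only; the real values come from the model config.
n_layers, batch, heads, ctx_len, d_kv = 2, 3, 4, 24, 16

# Persistent cross-KV cache: one K slot and one V slot per decoder layer.
cross_cache = torch.zeros(n_layers * 2, batch, heads, ctx_len, d_kv)

# Stand-ins for k_proj/v_proj outputs reshaped to (batch, heads, ctx_len, d_kv).
enc_hidden = torch.randn(batch, ctx_len, heads * d_kv)
per_layer_kv = []
for _ in range(n_layers):
    k = enc_hidden.view(batch, ctx_len, heads, d_kv).transpose(1, 2)
    v = enc_hidden.view(batch, ctx_len, heads, d_kv).transpose(1, 2)
    per_layer_kv += [k, v]

# Stand-in for rbln_cache_update: copy the freshly computed KV into the cache.
cross_cache.copy_(torch.stack(per_layer_kv, dim=0))

# Decoder side: pair the flat [k0, v0, k1, v1, ...] layout back up per layer.
layer_kv = [(cross_cache[i], cross_cache[i + 1]) for i in range(0, n_layers * 2, 2)]
assert layer_kv[0][0].shape == (batch, heads, ctx_len, d_kv)
```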
@@ -38,29 +38,34 @@ from .whisper_architecture import WhisperWrapper
  logger = get_logger(__name__)
 
  if TYPE_CHECKING:
-     from transformers import AutoFeatureExtractor, AutoProcessor, PretrainedConfig, PreTrainedModel
+     from transformers import AutoFeatureExtractor, AutoProcessor, GenerationConfig, PretrainedConfig, PreTrainedModel
 
 
  class RBLNRuntimeEncoder(RBLNPytorchRuntime):
      mandatory_members = ["main_input_name"]
 
-     def forward(self, input_features: torch.Tensor = None):
-         # backward compatibility transformers==4.40.2
-         # https://github.com/huggingface/transformers/blob/4fdf58afb72b0754da30037fc800b6044e7d9c99/src/transformers/pipelines/automatic_speech_recognition.py#L494
-
-         n_pad_to_batch = self.batch_size - input_features.shape[0]
-         if n_pad_to_batch > 0:
-             input_features = torch.nn.functional.pad(input_features, (0, 0, 0, 0, 0, n_pad_to_batch))
-
-         _ = super().forward(input_features=input_features)
-
-         # dummy output for generation
-         return BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
+     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+         output = super().forward(*args, **kwargs)
+         return BaseModelOutput(last_hidden_state=output)
 
 
  class RBLNRuntimeDecoder(RBLNPytorchRuntime):
      mandatory_members = ["main_input_name"]
 
+     def __init__(
+         self,
+         runtime: rebel.Runtime,
+         batch_size: int,
+         dec_max_seq_len: int,
+         use_attention_mask: Optional[bool] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(runtime, **kwargs)
+         self.batch_size = batch_size
+         self.dec_max_seq_len = dec_max_seq_len
+         self.use_attention_mask = use_attention_mask
+         self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
+
      def forward(
          self,
          decoder_input_ids: torch.Tensor = None,
@@ -69,13 +74,24 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
      ):
          inputs_bsz = decoder_input_ids.shape[0]
          padded_bsz = self.batch_size - inputs_bsz
+
          if padded_bsz > 0:
              decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, 0, 0, padded_bsz))
 
+         if self.use_attention_mask:
+             for b_idx in range(self.batch_size):
+                 decoding_step = cache_position[b_idx].item()
+                 if not (0 <= decoding_step < self.dec_max_seq_len):
+                     raise ValueError(
+                         f"Decoding step {decoding_step} is out of bounds for dec_max_seq_len {self.dec_max_seq_len}."
+                     )
+                 decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
          outputs = super().forward(
-             decoder_input_ids=decoder_input_ids,
-             decoder_attention_mask=decoder_attention_mask,
-             cache_position=cache_position,
+             decoder_input_ids,
+             decoder_attention_mask if self.use_attention_mask else None,
+             cache_position,
+             block_tables=self.default_block_tables,
          )
 
          if isinstance(outputs, torch.Tensor):
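The mask handling added above is per-step bookkeeping: when `use_attention_mask` is enabled, each batch row's mask is filled up to that row's current decoding step before the runtime call. A rough host-side sketch of the same update on plain tensors, with made-up sizes (the real mask and positions come from the generation loop):

```python
import torch

batch_size, dec_max_seq_len = 2, 8
decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len)

# Per-row decoding step, shape (batch_size, 1), as passed via cache_position.
cache_position = torch.tensor([[3], [5]], dtype=torch.int32)

for b_idx in range(batch_size):
    step = cache_position[b_idx].item()
    if not (0 <= step < dec_max_seq_len):
        raise ValueError(f"Decoding step {step} is out of bounds for dec_max_seq_len {dec_max_seq_len}.")
    # Positions 0..step are visible to the decoder at this step.
    decoder_attention_mask[b_idx, : step + 1] = 1

print(decoder_attention_mask)
# tensor([[1., 1., 1., 1., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 1., 1., 0., 0.]])
```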
@@ -104,12 +120,15 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
          self.batch_size = self.rbln_config.model_cfg["batch_size"]
          self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
          self.rbln_token_timestamps = self.rbln_config.model_cfg["token_timestamps"]
+         self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
 
-         self.encoder = RBLNRuntimeEncoder(
-             runtime=self.model[0], main_input_name="input_features", batch_size=self.batch_size
-         )
+         self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
          self.decoder = RBLNRuntimeDecoder(
-             runtime=self.model[1], main_input_name="input_ids", batch_size=self.batch_size
+             runtime=self.model[1],
+             main_input_name="input_ids",
+             batch_size=self.batch_size,
+             dec_max_seq_len=self.dec_max_seq_len,
+             use_attention_mask=self.use_attention_mask,
          )
 
          # skip encoder & first decoder when language detected
@@ -121,6 +140,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
          # input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
          self.model = WhisperModel(self.config)
          self.pad_token_id = self.config.pad_token_id
+         self.generation_config.forced_decoder_ids = None
 
      def can_generate(self):
          return True
@@ -152,7 +172,10 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
      @classmethod
      def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
          rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
-         return WhisperWrapper(model, rbln_token_timestamps)
+         use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
+         return WhisperWrapper(
+             model, use_attention_mask=use_attention_mask, rbln_token_timestamps=rbln_token_timestamps
+         )
 
      @classmethod
      @torch.inference_mode()
@@ -215,28 +238,34 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
          if rbln_dec_max_seq_len is None:
              rbln_dec_max_seq_len = model_config.max_length
 
-         # model input info
-         enc_input_info = [("input_features", [rbln_batch_size, num_mel_bins, expected_seq_len], "float32")]
-         enc_input_info.extend(
-             [
-                 (
-                     "cross_key_value_states",
-                     [
-                         model_config.decoder_layers * 2,
-                         rbln_batch_size,
-                         model_config.decoder_attention_heads,
-                         enc_max_seq_len,
-                         model_config.d_model // model_config.decoder_attention_heads,
-                     ],
-                     "float32",
-                 )
-             ]
-         )
+         # use_attention_mask conditions
+         rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+         if rbln_use_attention_mask is None:
+             rbln_use_attention_mask = False
+             rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+             if rbln_npu == "RBLN-CA02":
+                 rbln_use_attention_mask = True
+
+         enc_input_info = [
+             ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
+             ("block_tables", [1], "int16"),
+             (
+                 "cross_key_value_states",
+                 [
+                     model_config.decoder_layers * 2,
+                     rbln_batch_size,
+                     model_config.decoder_attention_heads,
+                     enc_max_seq_len,
+                     model_config.d_model // model_config.decoder_attention_heads,
+                 ],
+                 "float32",
+             ),
+         ]
 
          dec_input_info = [
              ("decoder_input_ids", [rbln_batch_size, 1], "int64"),
-             ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
-             ("cache_position", [], "int32"),
+             ("cache_position", [rbln_batch_size, 1], "int32"),
+             ("block_tables", [rbln_batch_size, 1], "int16"),
          ]
          dec_input_info.extend(
              [
@@ -269,6 +298,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
              ]
          )
 
+         if rbln_use_attention_mask:
+             dec_input_info.insert(1, ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+
          enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
          dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
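Putting the two pieces above together: the attention-mask input is only part of the decoder's compiled signature when the flag resolves to True, and it is slotted in right after `decoder_input_ids`. A small standalone illustration of that assembly with plain `(name, shape, dtype)` tuples and made-up sizes (no RBLN compiler calls, and `build_dec_input_info` is a hypothetical helper for this sketch only):

```python
def build_dec_input_info(batch_size: int, dec_max_seq_len: int, use_attention_mask: bool):
    # Order matters: the decoder runtime is later called positionally in this order.
    dec_input_info = [
        ("decoder_input_ids", [batch_size, 1], "int64"),
        ("cache_position", [batch_size, 1], "int32"),
        ("block_tables", [batch_size, 1], "int16"),
    ]
    if use_attention_mask:
        # Mirror insert(1, ...): the mask sits between decoder_input_ids and cache_position.
        dec_input_info.insert(1, ("decoder_attention_mask", [batch_size, dec_max_seq_len], "float32"))
    return dec_input_info


print([name for name, _, _ in build_dec_input_info(2, 128, True)])
# ['decoder_input_ids', 'decoder_attention_mask', 'cache_position', 'block_tables']
```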
@@ -283,6 +315,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
                  "batch_size": rbln_batch_size,
                  "dec_max_seq_len": rbln_dec_max_seq_len,
                  "token_timestamps": rbln_token_timestamps,
+                 "use_attention_mask": rbln_use_attention_mask,
              }
          )
 
@@ -327,11 +360,25 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
 
      # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L512
      def _prepare_encoder_decoder_kwargs_for_generation(
-         self, inputs_tensor: torch.Tensor, model_kwargs, *args, **kwargs
+         self,
+         inputs_tensor: torch.Tensor,
+         model_kwargs,
+         model_input_name: Optional[str] = None,
+         generation_config: Optional["GenerationConfig"] = None,
+         **kwargs,
      ) -> Dict[str, Any]:
+         batch_size = inputs_tensor.shape[0]
+         n_pad_to_batch = self.batch_size - batch_size
+         if n_pad_to_batch > 0:
+             inputs_tensor = torch.nn.functional.pad(inputs_tensor, (0, 0, 0, 0, 0, n_pad_to_batch))
+
          if not self.is_language_detected:
-             model_kwargs["encoder_outputs"] = self.encoder(input_features=inputs_tensor)
-             self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+             for b in range(inputs_tensor.shape[0]):
+                 block_tables = torch.tensor([b], dtype=torch.int16)
+                 model_kwargs["encoder_outputs"] = self.encoder(
+                     input_features=inputs_tensor[b].unsqueeze(0), block_tables=block_tables
+                 )
+             self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
          else:
              model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
 
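The generation-side change above moves batch padding out of the encoder runtime and into this hook, then encodes one sample at a time with its own block-table index. A rough sketch of that padding-plus-loop pattern on plain tensors; `fake_encoder` and the sizes are stand-ins for this illustration, not the RBLN runtime:

```python
import torch

compiled_batch, n_mels, frames = 4, 80, 3000
input_features = torch.randn(2, n_mels, frames)  # only 2 real samples this call

# Pad the batch dimension up to the compiled batch size.
n_pad = compiled_batch - input_features.shape[0]
if n_pad > 0:
    input_features = torch.nn.functional.pad(input_features, (0, 0, 0, 0, 0, n_pad))


def fake_encoder(features: torch.Tensor, block_tables: torch.Tensor) -> torch.Tensor:
    # Stand-in for the compiled encoder; just checks the per-sample shapes.
    assert features.shape[0] == 1 and block_tables.shape == (1,)
    return features.mean(dim=(1, 2))


# One encoder call per (padded) sample, each writing into its own cache block.
for b in range(input_features.shape[0]):
    block_tables = torch.tensor([b], dtype=torch.int16)
    fake_encoder(input_features[b].unsqueeze(0), block_tables=block_tables)
```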
@@ -359,7 +406,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
              decoder_output = self.decoder(
                  decoder_input_ids=input_ids[:, step : step + 1].contiguous(),
                  decoder_attention_mask=self.decoder_attention_mask,
-                 cache_position=step.to(torch.int32),
+                 cache_position=torch.full((self.batch_size, 1), step, dtype=torch.int32),
              )
              cross_attentions.append(decoder_output.cross_attentions)
              lm_logits = decoder_output.logits
@@ -374,15 +421,19 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
          # detect language pass
          # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/whisper/generation_whisper.py#L1442
          else:
+             # for language auto detection (generate with language=None)
              if encoder_outputs is None:
-                 self.encoder(input_features=input_features.contiguous())
-                 self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+                 for b in range(input_features.shape[0]):
+                     block_tables = torch.tensor([b], dtype=torch.int16)
+                     self.encoder(input_features=input_features[b].unsqueeze(0), block_tables=block_tables)
+
+                 self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
              self.is_language_detected = True
              self.decoder_attention_mask[:, 0] = 1
              decoder_output = self.decoder(
                  decoder_input_ids=decoder_input_ids.contiguous(),
                  decoder_attention_mask=self.decoder_attention_mask,
-                 cache_position=torch.zeros([], dtype=torch.int32),
+                 cache_position=torch.zeros([self.rbln_config.model_cfg["batch_size"], 1], dtype=torch.int32),
              )
              lm_logits = decoder_output.logits
              self.language_cross = decoder_output.cross_attentions
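One last shape change worth noting: `cache_position` is no longer a scalar but a per-row `(batch_size, 1)` tensor, so every sequence in the padded batch carries its own decoding step. A minimal sketch of the two ways it is built in the hunks above (a constant step for the timestamp loop, zeros for the language-detection pass), with an assumed batch size of 2:

```python
import torch

batch_size = 2

# Timestamp-extraction loop: every row is at the same step of the forced decode.
step = 7
cache_position = torch.full((batch_size, 1), step, dtype=torch.int32)
assert cache_position.shape == (batch_size, 1)

# Language-detection pass: a single decoder call at step 0 for all rows.
cache_position_first = torch.zeros([batch_size, 1], dtype=torch.int32)

print(cache_position.squeeze(1).tolist(), cache_position_first.squeeze(1).tolist())
# [7, 7] [0, 0]
```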