optimum-rbln 0.7.3a1__py3-none-any.whl → 0.7.3a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,13 +12,17 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Tuple
+ from typing import Optional, Tuple

  import torch
  from torch import nn
  from transformers.utils import logging

- from ....ops import register_rbln_custom_cache_update, register_rbln_custom_masked_attention
+ from ....ops import (
+     register_rbln_custom_cache_update,
+     register_rbln_custom_paged_attention,
+     register_rbln_custom_paged_causal_attention,
+ )


  logger = logging.get_logger(__name__)
@@ -87,7 +91,7 @@ class Seq2SeqEncoderWrapper(nn.Module):
          input_ids: torch.Tensor,
          attention_mask: torch.Tensor,
          cross_key_values: torch.Tensor,
-         batch_position: torch.Tensor,
+         b_idx: torch.Tensor,
      ) -> Tuple[torch.Tensor]:
          # 1. get encoder last_hidden_states
          encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
@@ -111,7 +115,7 @@ class Seq2SeqEncoderWrapper(nn.Module):
          # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
          batch_axis = torch.tensor(1, dtype=torch.int16)
          cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(
-             cross_key_values, cross_kv, batch_position, batch_axis
+             cross_key_values, cross_kv, b_idx[0], batch_axis
          )

          return cross_key_values
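In plain PyTorch terms, the rewritten cache update is roughly a slice assignment along the batch axis (axis 1) at slot `b_idx[0]`; the custom op performs the same write directly in device DRAM. A minimal illustrative sketch, with shapes assumed from the cross_key_value_states layout compiled further below (not the RBLN op itself):

    import torch

    def cache_update_reference(cache: torch.Tensor, new_kv: torch.Tensor, b_idx0: int) -> torch.Tensor:
        # cache:  [n_layer * 2, batch, n_head, enc_max_seq_len, d_kv]  (preallocated cross-KV cache)
        # new_kv: [n_layer * 2, 1, n_head, enc_max_seq_len, d_kv]      (encoder output for one sequence)
        updated = cache.clone()
        updated[:, b_idx0 : b_idx0 + new_kv.shape[1]] = new_kv  # write into the requested batch slot
        return updated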
@@ -131,9 +135,10 @@ class Seq2SeqDecoderWrapper(nn.Module):
          **kwargs: Additional arguments for decoder configuration.
      """

-     def __init__(self, model: nn.Module, **kwargs):
+     def __init__(self, model: nn.Module, use_attention_mask: bool = True, **kwargs):
          super().__init__()
          self.config = model.config
+         self.use_attention_mask = use_attention_mask
          self.__post_init__(model, **kwargs)

      def __post_init__(self, model: nn.Module, **kwargs):
@@ -143,7 +148,11 @@ class Seq2SeqDecoderWrapper(nn.Module):
          It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
          by subclasses to modify or add custom attributes as necessary.
          """
-         register_rbln_custom_masked_attention()
+         if self.use_attention_mask:
+             register_rbln_custom_paged_attention()
+         else:
+             register_rbln_custom_paged_causal_attention()
+
          self.num_layers = self.config.decoder_layers
          self.conditional_generation = self.convert_to_rbln_conditional_generation(model)

@@ -160,13 +169,23 @@ class Seq2SeqDecoderWrapper(nn.Module):

      def forward(
          self,
-         input_ids: torch.Tensor,
-         attention_mask: torch.Tensor,
-         encoder_attention_mask: torch.Tensor,
-         cache_position: torch.Tensor,
-         cross_kv_cache: torch.Tensor,
-         *self_kv_cache: torch.Tensor,
+         *args,
      ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+         if self.use_attention_mask:
+             (
+                 input_ids,
+                 attention_mask,
+                 encoder_attention_mask,
+                 cache_position,
+                 block_tables,
+                 cross_kv_cache,
+                 *self_kv_cache,
+             ) = args
+
+         else:
+             attention_mask = None
+             (input_ids, encoder_attention_mask, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
+
          self_past_key_values = ()
          cross_past_key_values = ()
          for i in range(0, self.num_layers * 2, 2):
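Because the wrapper now receives a single flattened `*args` tuple, the caller must order the inputs to match the unpacking above. A small sketch of the two layouts (names such as `decoder_wrapper` are illustrative; the compiled runtime supplies the real tensors):

    # With an explicit decoder attention mask (use_attention_mask=True):
    args = (input_ids, attention_mask, encoder_attention_mask,
            cache_position, block_tables, cross_kv_cache, *self_kv_cache)

    # Without one (the paged causal variant derives the mask from cache_position):
    # args = (input_ids, encoder_attention_mask,
    #         cache_position, block_tables, cross_kv_cache, *self_kv_cache)

    logits, *self_present_kv = decoder_wrapper(*args)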
@@ -181,6 +200,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
              self_past_key_values=self_past_key_values,
              cross_past_key_values=cross_past_key_values,
              cache_position=cache_position,
+             block_tables=block_tables,
          )

          outputs = (lm_logits,) + self_present_key_values
@@ -228,6 +248,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
          self_past_key_values,
          cross_past_key_values,
          cache_position,
+         block_tables: Optional[torch.Tensor] = None,
      ):
          hidden_states, self_present_key_values = self.decoder(
              input_ids=input_ids,
@@ -236,6 +257,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
              self_past_key_values=self_past_key_values,
              cross_past_key_values=cross_past_key_values,
              cache_position=cache_position,
+             block_tables=block_tables,
          )

          if self.has_rescaling and self.config.tie_word_embeddings:
@@ -292,6 +314,7 @@ class Seq2SeqDecoder(torch.nn.Module):
          self_past_key_values: torch.Tensor,
          cross_past_key_values: torch.Tensor,
          cache_position: torch.Tensor,
+         block_tables: Optional[torch.Tensor] = None,
      ):
          # embedding
          hidden_states = self.get_embedding()(input_ids)
@@ -314,6 +337,7 @@ class Seq2SeqDecoder(torch.nn.Module):
                  self_past_key_value=self_past_key_value,
                  cross_past_key_value=cross_past_key_value,
                  cache_position=cache_position,
+                 block_tables=block_tables,
              )
              self_present_key_values += self_present_key_value

@@ -373,6 +397,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
          self_past_key_value: Tuple[torch.Tensor],
          cross_past_key_value: Tuple[torch.Tensor],
          cache_position: torch.Tensor,
+         block_tables: Optional[torch.Tensor] = None,
      ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
          dummy_encoder_hidden_states = torch.zeros(1, encoder_attention_mask.shape[-1])

@@ -384,6 +409,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
              past_key_value=self_past_key_value,
              attention_mask=attention_mask,
              cache_position=cache_position,
+             block_tables=block_tables,
          )
          hidden_states = residual + hidden_states
          hidden_states = self.post_self_attn_layer_norm(hidden_states)
@@ -407,10 +433,10 @@ class Seq2SeqDecoderLayer(torch.nn.Module):


  class Seq2SeqSelfAttention(nn.Module):
-     def __init__(self, attn):
+     def __init__(self, attn, **kwargs):
          super().__init__()
          self._original_mod = attn
-         self.__post_init__()
+         self.__post_init__(**kwargs)

      def __post_init__(self, **kwargs):
          """
@@ -442,6 +468,7 @@ class Seq2SeqSelfAttention(nn.Module):
          past_key_value: Tuple[torch.Tensor],
          attention_mask: torch.Tensor,
          cache_position: torch.Tensor,
+         block_tables: Optional[torch.Tensor] = None,
      ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
          bsz, tgt_len, _ = hidden_states.size()

@@ -450,18 +477,22 @@ class Seq2SeqSelfAttention(nn.Module):
          key_states = self._shape(key_states, -1, bsz)
          value_states = self._shape(value_states, -1, bsz)

-         attn_output, key_states, value_states = self.attn_decode(
+         block_size = past_key_value[0].shape[-2]
+         args = [
              query_states,
              key_states,
              value_states,
-             attention_mask.unsqueeze(
-                 2
-             ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
              past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
              past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
              cache_position,
              torch.tensor(1.0, dtype=torch.float32),  # scale
-         )
+             block_tables,
+             block_size,
+         ]
+         if attention_mask is not None:
+             args.insert(3, attention_mask.unsqueeze(2))
+
+         attn_output, key_states, value_states = self.attn_decode(*args)

          attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
          attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
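The self-attention decode now targets a paged KV cache: `block_size` is read from the cache layout, `block_tables` tells the kernel which physical block holds each sequence, and the decoder mask becomes optional (the paged causal variant derives it from `cache_position`). As a general illustration of block-table addressing, a plain-PyTorch sketch follows; it is illustrative only, since the real contract is defined by the registered rbln_custom_ops kernels, and here each sequence effectively occupies a single block because block_size equals the full per-sequence cache length:

    import torch

    def gather_keys(key_cache: torch.Tensor, block_table: torch.Tensor, seq_len: int) -> torch.Tensor:
        # key_cache:   [num_blocks, n_head, block_size, head_dim]  -- physical cache blocks
        # block_table: 1-D int tensor of physical block ids for one sequence
        n_head, block_size, head_dim = key_cache.shape[1:]
        blocks = key_cache[block_table.long()]                 # [n_blocks, n_head, block_size, head_dim]
        keys = blocks.permute(1, 0, 2, 3).reshape(n_head, -1, head_dim)
        return keys[:, :seq_len]                               # contiguous logical view of this sequence's keys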
@@ -13,8 +13,9 @@
  # limitations under the License.

  import inspect
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

+ import rebel
  import torch
  from transformers import (
      AutoModelForTextEncoding,
@@ -22,7 +23,7 @@ from transformers import (
      T5EncoderModel,
      T5ForConditionalGeneration,
  )
- from transformers.modeling_outputs import BaseModelOutput
+ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

  from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
  from ....modeling import RBLNModel
@@ -57,6 +58,63 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          )


+ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
+     mandatory_members = ["main_input_name"]
+
+     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+         _ = super().forward(*args, **kwargs)
+         return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
+
+
+ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
+     mandatory_members = ["main_input_name"]
+
+     def __init__(
+         self,
+         runtime: rebel.Runtime,
+         batch_size: int,
+         dec_max_seq_len: int,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(runtime, **kwargs)
+         self.batch_size = batch_size
+         self.dec_max_seq_len = dec_max_seq_len
+
+     def forward(
+         self,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         decoder_attention_mask: Optional[torch.BoolTensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor]:
+         batch_size = decoder_input_ids.shape[0]
+         if batch_size != self.batch_size:
+             raise RuntimeError(
+                 f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+             )
+
+         if batch_size != cache_position.shape[0]:
+             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+         for b_idx in range(self.batch_size):
+             decoding_step = cache_position[b_idx].item()
+             if not (0 <= decoding_step < self.dec_max_seq_len):
+                 raise ValueError(
+                     f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                 )
+             decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+         lm_logits = super().forward(
+             decoder_input_ids,
+             decoder_attention_mask,
+             attention_mask,
+             cache_position,
+         )
+
+         return Seq2SeqLMOutput(logits=lm_logits)
+
+
  class T5EncoderWrapper(torch.nn.Module):
      def __init__(self, model: "T5EncoderModel") -> None:
          super().__init__()
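RBLNRuntimeDecoder expects the compiled batch size and a per-sequence decoding step; it marks positions 0..step as visible in `decoder_attention_mask` before invoking the compiled graph. A hypothetical single-step call with the tensor shapes the runtime expects (values are illustrative, and `runtime` would be a compiled rebel.Runtime, so the call itself is left commented):

    import torch

    batch_size, enc_max_seq_len, dec_max_seq_len = 2, 512, 256  # from the compiled rbln_config

    decoder_input_ids = torch.zeros(batch_size, 1, dtype=torch.int64)
    encoder_attention_mask = torch.ones(batch_size, enc_max_seq_len, dtype=torch.float32)
    decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len, dtype=torch.float32)
    cache_position = torch.tensor([[3], [7]], dtype=torch.int32)  # current step per sequence

    # decoder = RBLNRuntimeDecoder(runtime, main_input_name="input_ids",
    #                              batch_size=batch_size, dec_max_seq_len=dec_max_seq_len)
    # out = decoder(decoder_input_ids=decoder_input_ids,
    #               attention_mask=encoder_attention_mask,
    #               decoder_attention_mask=decoder_attention_mask,
    #               cache_position=cache_position)
    # out.logits holds the next-token logits for each sequence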
@@ -189,6 +247,21 @@ class RBLNT5EncoderModel(RBLNModel):


  class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
+     def __post_init__(self, **kwargs):
+         batch_size = self.rbln_config.model_cfg["batch_size"]
+         dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+
+         self.encoder = RBLNRuntimeEncoder(
+             runtime=self.model[0],
+             main_input_name="input_ids",
+         )
+         self.decoder = RBLNRuntimeDecoder(
+             runtime=self.model[1],
+             main_input_name="input_ids",
+             batch_size=batch_size,
+             dec_max_seq_len=dec_max_seq_len,
+         )
+
      @classmethod
      def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
          enc_max_seq_len = rbln_config.model_cfg["enc_max_seq_len"]
@@ -206,3 +279,139 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
                  return redirect(val)

          return val
+
+     @classmethod
+     def _get_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model_config: "PretrainedConfig",
+         rbln_kwargs: Dict[str, Any] = {},
+     ) -> RBLNConfig:
+         rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
+         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
+         rbln_batch_size = rbln_kwargs.get("batch_size", None)
+         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+
+         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
+         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
+         d_kv = (
+             model_config.d_kv
+             if hasattr(model_config, "d_kv")
+             else model_config.d_model // model_config.encoder_attention_heads
+         )
+
+         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+             model_config, "max_position_embeddings", None
+         )
+
+         rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
+         if rbln_pad_token_id is None:
+             rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
+             if rbln_pad_token_id is None:
+                 rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
+                 if rbln_pad_token_id is None:
+                     rbln_pad_token_id = -1
+
+         if rbln_enc_max_seq_len is None:
+             rbln_enc_max_seq_len = max_position_embeddings
+             if rbln_enc_max_seq_len is None:
+                 for tokenizer in preprocessors:
+                     if hasattr(tokenizer, "model_max_length"):
+                         rbln_enc_max_seq_len = tokenizer.model_max_length
+                         break
+                 if rbln_enc_max_seq_len is None:
+                     raise ValueError("`rbln_enc_max_seq_len` should be specified!")
+         if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
+             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+         if rbln_dec_max_seq_len is None:
+             rbln_dec_max_seq_len = max_position_embeddings
+             if rbln_dec_max_seq_len is None:
+                 for tokenizer in preprocessors:
+                     if hasattr(tokenizer, "model_max_length"):
+                         rbln_dec_max_seq_len = tokenizer.model_max_length
+                         break
+                 if rbln_dec_max_seq_len is None:
+                     raise ValueError("`rbln_dec_max_seq_len` should be specified!")
+
+         if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
+             raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
+
+         # model input info
+         enc_input_info = [
+             ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
+             ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
+             (
+                 "cross_key_value_states",
+                 [
+                     n_layer * 2,
+                     rbln_batch_size,
+                     n_head,
+                     rbln_enc_max_seq_len,
+                     d_kv,
+                 ],
+                 "float32",
+             ),
+             ("block_tables", [1], "int16"),
+         ]
+
+         dec_input_info = [
+             ("input_ids", [rbln_batch_size, 1], "int64"),
+             ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
+             ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
+             (
+                 "cache_position",
+                 [rbln_batch_size, 1],
+                 "int32",
+             ),
+         ]
+         dec_input_info.extend(
+             [
+                 (
+                     "cross_key_value_states",
+                     [
+                         n_layer * 2,
+                         rbln_batch_size,
+                         n_head,
+                         rbln_enc_max_seq_len,
+                         d_kv,
+                     ],
+                     "float32",
+                 )
+             ]
+         )
+         dec_input_info.extend(
+             [
+                 (
+                     f"self_key_value_states_{i}",
+                     [
+                         rbln_batch_size,
+                         n_head,
+                         rbln_dec_max_seq_len,
+                         d_kv,
+                     ],
+                     "float32",
+                 )
+                 for i in range(n_layer * 2)
+             ]
+         )
+
+         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+
+         rbln_config = RBLNConfig(
+             rbln_cls=cls.__name__,
+             compile_cfgs=[enc_compile_config, dec_compile_config],
+             rbln_kwargs=rbln_kwargs,
+         )
+
+         rbln_config.model_cfg.update(
+             {
+                 "enc_max_seq_len": rbln_enc_max_seq_len,
+                 "dec_max_seq_len": rbln_dec_max_seq_len,
+                 "batch_size": rbln_batch_size,
+                 "pad_token_id": rbln_pad_token_id,
+             }
+         )
+
+         return rbln_config
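End to end, the encoder is compiled at batch 1 (it writes cross-K/V into the slot selected by `block_tables`) while the decoder is compiled at the requested batch size with full-length masks. A hypothetical export call, assuming the usual optimum-rbln pattern where `rbln_*`-prefixed keyword arguments are stripped of their prefix and land in `rbln_kwargs`:

    from optimum.rbln import RBLNT5ForConditionalGeneration

    # Illustrative usage; the kwarg names mirror the rbln_kwargs keys read in _get_rbln_config.
    model = RBLNT5ForConditionalGeneration.from_pretrained(
        "google-t5/t5-small",
        export=True,              # compile the encoder and decoder for the RBLN NPU
        rbln_batch_size=1,
        rbln_enc_max_seq_len=512,
        rbln_dec_max_seq_len=256,
    )
    model.save_pretrained("t5-small-rbln")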
@@ -18,7 +18,7 @@ import torch
  from torch import nn
  from transformers.utils import logging

- from ....ops import register_rbln_custom_attention_add_softmax
+ from ....ops import register_rbln_custom_add_softmax_attention
  from ..seq2seq.seq2seq_architecture import (
      Seq2SeqDecoder,
      Seq2SeqDecoderLayer,
@@ -55,7 +55,7 @@ class T5EncoderWrapper(Seq2SeqEncoderWrapper):

  class T5DecoderWrapper(Seq2SeqDecoderWrapper):
      def __post_init__(self, model, dec_max_seq_len: int = None):
-         register_rbln_custom_attention_add_softmax()
+         register_rbln_custom_add_softmax_attention()
          self.num_layers = self.config.num_layers
          self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)

@@ -71,6 +71,36 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):

          return new_model

+     def forward(
+         self,
+         input_ids,
+         attention_mask,
+         encoder_attention_mask,
+         cache_position,
+         cross_kv_cache,
+         *self_kv_cache,
+     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+         self_past_key_values = ()
+         cross_past_key_values = ()
+
+         for i in range(0, self.num_layers * 2, 2):
+             self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+
+         # decode
+         lm_logits, self_present_key_values = self.conditional_generation(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             encoder_attention_mask=encoder_attention_mask,
+             self_past_key_values=self_past_key_values,
+             cross_past_key_values=cross_past_key_values,
+             cache_position=cache_position,
+         )
+
+         outputs = (lm_logits,) + self_present_key_values
+
+         return outputs
+

  class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
      has_rescaling = True
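The flattened cache layout these wrappers consume interleaves keys and values per layer: index 2*i is layer i's key cache and 2*i + 1 its value cache, which is exactly what the `range(0, self.num_layers * 2, 2)` loop reassembles. A small illustrative helper (not part of the package):

    from typing import Sequence, Tuple
    import torch

    def pair_kv_caches(flat_kv: Sequence[torch.Tensor], num_layers: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        # flat_kv layout: [k_0, v_0, k_1, v_1, ..., k_{n-1}, v_{n-1}]
        return tuple((flat_kv[2 * i], flat_kv[2 * i + 1]) for i in range(num_layers))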
@@ -134,7 +164,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
          self.out_proj = self._original_mod.o
          self.num_heads = self._original_mod.n_heads
          self.head_dim = self._original_mod.key_value_proj_dim
-         self.attn_decode = torch.ops.rbln_custom_ops.attn_decode_add_softmax
+         self.attn_decode = torch.ops.rbln_custom_ops.add_softmax_attn_decode

      def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          query_states = self.q_proj(hidden_states)
@@ -142,6 +172,42 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
          value_states = self.v_proj(hidden_states)
          return query_states, key_states, value_states

+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         past_key_value: Tuple[torch.Tensor],
+         attention_mask: torch.Tensor,
+         cache_position: torch.Tensor,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+         bsz, tgt_len, _ = hidden_states.size()
+
+         query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+         query_states = self._shape(query_states, tgt_len, bsz)
+         key_states = self._shape(key_states, -1, bsz)
+         value_states = self._shape(value_states, -1, bsz)
+
+         attn_output, key_states, value_states = self.attn_decode(
+             query_states,
+             key_states,
+             value_states,
+             attention_mask.unsqueeze(
+                 2
+             ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
+             past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+             past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+             cache_position,
+             torch.tensor(1.0, dtype=torch.float32),  # scale
+         )
+
+         attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
+         attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
+
+         attn_output = self.out_proj(attn_output)
+         present_key_value = (key_states, value_states)
+
+         return attn_output, present_key_value
+

  class T5CrossAttention(nn.Module):
      def __init__(self, attn):
@@ -25,7 +25,7 @@ from transformers.modeling_outputs import (
  )
  from transformers.utils import logging

- from ....ops import register_rbln_custom_cache_update
+ from ....ops import register_rbln_custom_add_softmax_attention, register_rbln_custom_cache_update


  logger = logging.get_logger(__name__)
@@ -34,6 +34,7 @@ logger = logging.get_logger(__name__)
  class WhisperWrapper:
      def __init__(self, model, rbln_token_timestamps):
          register_rbln_custom_cache_update()
+         register_rbln_custom_add_softmax_attention()
          self.encoder = WhisperEncoderWrapper(model)
          self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)

@@ -213,7 +214,7 @@ class WhisperDecoderLayer(nn.Module):
          # Self Attention Block
          residual = hidden_states
          hidden_states = self.self_attn_layer_norm(hidden_states)
-         hidden_states, _, self_present_key_value = self.self_attn(
+         hidden_states, self_present_key_value = self.self_attn(
              hidden_states=hidden_states,
              past_key_value=self_past_key_value,
              attention_mask=attention_mask,
@@ -224,7 +225,7 @@ class WhisperDecoderLayer(nn.Module):
          # Cross-Attention Block
          residual = hidden_states
          hidden_states = self.encoder_attn_layer_norm(hidden_states)
-         hidden_states, cross_attn_weights, cross_present_key_value = self.encoder_attn(
+         hidden_states, cross_attn_weights = self.encoder_attn(
              hidden_states=hidden_states,
              past_key_value=cross_past_key_value,
          )
@@ -258,19 +259,8 @@ class WhisperAttention(nn.Module):


  class WhisperSelfAttention(WhisperAttention):
-     def rbln_cache_update(
-         self,
-         past_key_value: torch.Tensor,
-         key_states: torch.Tensor,
-         value_states: torch.Tensor,
-         cache_position: torch.Tensor,
-     ):
-         s_idx = torch.tensor(cache_position, dtype=torch.int16)
-         axis = torch.tensor(2, dtype=torch.int16)
-
-         key_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[0], key_states, s_idx, axis)
-         value_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[1], value_states, s_idx, axis)
-         return key_states, value_states
+     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+         return tensor.view(bsz, seq_len, 1, self.num_heads, self.head_dim).transpose(1, 3)

      def forward(
          self,
@@ -285,22 +275,27 @@ class WhisperSelfAttention(WhisperAttention):

          key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
          value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-         key_states, value_states = self.rbln_cache_update(past_key_value, key_states, value_states, cache_position)

-         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
-         attn_weights = attn_weights + attention_mask
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+         attn_output, key_states, value_states = torch.ops.rbln_custom_ops.add_softmax_attn_decode(
+             query_states,
+             key_states,
+             value_states,
+             attention_mask.unsqueeze(2),
+             past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+             past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+             cache_position.expand(bsz, 1),
+             torch.tensor(1.0, dtype=torch.float32),  # scale
+         )

-         attn_output = torch.matmul(attn_weights, value_states)
          attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
          attn_output = attn_output.transpose(1, 2)
          attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
          attn_output = self.out_proj(attn_output)

-         return attn_output, attn_weights, (key_states, value_states)
+         return attn_output, (key_states, value_states)


- class WhisperCrossAttention(WhisperSelfAttention):
+ class WhisperCrossAttention(WhisperAttention):
      def forward(
          self,
          hidden_states: torch.Tensor,
@@ -322,4 +317,4 @@ class WhisperCrossAttention(WhisperSelfAttention):
          attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
          attn_output = self.out_proj(attn_output)

-         return attn_output, attn_weights, (key_states, value_states)
+         return attn_output, attn_weights
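The fused `add_softmax_attn_decode` call replaces the removed eager sequence: update the running K/V cache at `cache_position`, add the mask to the attention scores, softmax, then take the weighted sum of values. A rough plain-PyTorch reference of that sequence with simplified shapes (the real kernel additionally expects a group axis for grouped-query attention, hence the `unsqueeze(2)` on the mask, and it returns the updated caches alongside the output):

    import torch

    def add_softmax_attn_decode_reference(q, k_new, v_new, mask, k_cache, v_cache, pos, scale):
        # q:                [bsz, n_head, tgt_len, head_dim]
        # k_new, v_new:     [bsz, n_head, new_len, head_dim]
        # k_cache, v_cache: [bsz, n_head, max_len, head_dim]
        k_cache[..., pos : pos + k_new.shape[-2], :] = k_new   # write the new entries into the cache
        v_cache[..., pos : pos + v_new.shape[-2], :] = v_new
        scores = torch.matmul(q, k_cache.transpose(-1, -2)) * scale + mask  # scale is passed as 1.0 above
        probs = torch.softmax(scores, dim=-1)
        return torch.matmul(probs, v_cache), k_cache, v_cache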
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: optimum-rbln
- Version: 0.7.3a1
+ Version: 0.7.3a3
  Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
  Project-URL: Homepage, https://rebellions.ai
  Project-URL: Documentation, https://docs.rbln.ai