optimum-rbln 0.7.2rc1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +9 -4
  3. optimum/rbln/diffusers/__init__.py +8 -0
  4. optimum/rbln/diffusers/modeling_diffusers.py +103 -109
  5. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -3
  6. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +15 -8
  7. optimum/rbln/diffusers/pipelines/__init__.py +8 -0
  8. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +7 -1
  9. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +25 -0
  10. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +107 -1
  11. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +25 -0
  12. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +3 -0
  13. optimum/rbln/modeling.py +4 -1
  14. optimum/rbln/modeling_base.py +16 -3
  15. optimum/rbln/ops/__init__.py +6 -2
  16. optimum/rbln/ops/attn.py +94 -85
  17. optimum/rbln/ops/flash_attn.py +46 -25
  18. optimum/rbln/ops/kv_cache_update.py +4 -4
  19. optimum/rbln/transformers/modeling_generic.py +3 -3
  20. optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
  21. optimum/rbln/transformers/models/bart/modeling_bart.py +6 -2
  22. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
  23. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +264 -133
  24. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +276 -29
  25. optimum/rbln/transformers/models/exaone/exaone_architecture.py +11 -4
  26. optimum/rbln/transformers/models/gemma/gemma_architecture.py +11 -4
  27. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +5 -3
  28. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -3
  29. optimum/rbln/transformers/models/phi/phi_architecture.py +9 -7
  30. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +50 -13
  31. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +60 -36
  32. optimum/rbln/transformers/models/t5/modeling_t5.py +3 -1
  33. optimum/rbln/transformers/models/t5/t5_architecture.py +65 -3
  34. optimum/rbln/transformers/models/whisper/whisper_architecture.py +26 -36
  35. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -14
  36. optimum/rbln/utils/import_utils.py +7 -0
  37. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/METADATA +1 -1
  38. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/RECORD +40 -38
  39. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/WHEEL +0 -0
  40. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/phi/phi_architecture.py
@@ -32,11 +32,13 @@ if TYPE_CHECKING:


 class PhiWrapper(DecoderOnlyWrapper):
-    def convert_to_rbln_causal_lm(self, causal_lm: "PhiForCausalLM"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "PhiForCausalLM", max_seq_len: int):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = PhiAttention(layer.self_attn)
+                new_self_attn = PhiAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
             else:
@@ -81,30 +83,30 @@ class PhiLayer(DecoderOnlyLayer):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states

         hidden_states = self.get_pre_attention_layernorm()(hidden_states)

-        attn_outputs, present_key_values = self.self_attn(
+        attn_output = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             seq_positions=seq_positions,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             cos=cos,
             sin=sin,
+            block_tables=block_tables,
         )

         feed_forward_hidden_states = self._original_mod.mlp(hidden_states)

-        hidden_states = attn_outputs + feed_forward_hidden_states + residual
+        hidden_states = attn_output + feed_forward_hidden_states + residual

-        return hidden_states, present_key_values
+        return hidden_states


 class PhiModel(DecoderOnlyModel):
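
A minimal sketch of the call-contract change above (illustrative only, not part of the diff): with paged attention the KV cache is addressed through block_tables and written in place inside the custom attention op, so wrapped layers stop returning present_key_values.

# --- illustrative sketch, not part of the diff; `layer` stands in for a wrapped PhiLayer ---
import torch

batch_size = 2  # assumed
block_tables = torch.arange(batch_size, dtype=torch.int16).view(batch_size, 1)

# 0.7.2rc1: hidden_states, present_key_values = layer(..., batch_position=batch_position, ...)
# 0.7.3:    hidden_states = layer(..., block_tables=block_tables, ...)
# The cache write happens inside the paged attention op, addressed by block_tables,
# so only hidden_states flows back up the stack.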

optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -50,11 +50,18 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         runtime: rebel.Runtime,
         batch_size: int,
         dec_max_seq_len: int,
+        support_paged_causal_attn: Optional[bool] = None,
+        use_attention_mask: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.batch_size = batch_size
         self.dec_max_seq_len = dec_max_seq_len
+        self.use_attention_mask = use_attention_mask
+        if support_paged_causal_attn:
+            self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
+        else:
+            self.default_block_tables = None

     def forward(
         self,
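
For reference, a small sketch of what the default block table above evaluates to, assuming batch_size = 4 (not part of the diff): one cache block per batch row.

# --- illustrative sketch, not part of the diff ---
import torch

batch_size = 4  # assumed value
default_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, 1)
print(default_block_tables)
# tensor([[0],
#         [1],
#         [2],
#         [3]], dtype=torch.int16)
# Row b means: batch element b reads and writes KV-cache block b. The seq2seq decoder
# uses a single block per sequence, so each row holds exactly one block index.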
@@ -62,6 +69,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         attention_mask: Optional[torch.FloatTensor] = None,
         decoder_attention_mask: Optional[torch.BoolTensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         batch_size = decoder_input_ids.shape[0]
@@ -73,19 +81,24 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")

-        for b_idx in range(self.batch_size):
-            decoding_step = cache_position[b_idx].item()
-            if not (0 <= decoding_step < self.dec_max_seq_len):
-                raise ValueError(
-                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                )
-            decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+        if self.use_attention_mask:
+            for b_idx in range(self.batch_size):
+                decoding_step = cache_position[b_idx].item()
+                if not (0 <= decoding_step < self.dec_max_seq_len):
+                    raise ValueError(
+                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    )
+                decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+        if block_tables is None:
+            block_tables = self.default_block_tables

         lm_logits = super().forward(
             decoder_input_ids,
-            decoder_attention_mask,
+            decoder_attention_mask if self.use_attention_mask else None,
             attention_mask,
             cache_position,
+            block_tables=block_tables,
         )

         return Seq2SeqLMOutput(logits=lm_logits)
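
A worked example of the per-step mask fill kept for the use_attention_mask path, assuming dec_max_seq_len = 8 and a batch of two (not part of the diff):

# --- illustrative sketch, not part of the diff ---
import torch

dec_max_seq_len, batch_size = 8, 2          # assumed values
cache_position = torch.tensor([[3], [5]])   # current decoding step per batch row
decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len)

for b_idx in range(batch_size):
    decoding_step = cache_position[b_idx].item()
    decoder_attention_mask[b_idx, : decoding_step + 1] = 1

print(decoder_attention_mask)
# tensor([[1., 1., 1., 1., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 1., 1., 0., 0.]])
# With use_attention_mask=False the mask is dropped and the paged *causal* kernel
# derives the valid range from cache_position on device instead.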
@@ -106,16 +119,24 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):

     main_input_name = "input_ids"
     auto_model_class = AutoModelForSeq2SeqLM
+    support_paged_causal_attn = None

     def __post_init__(self, **kwargs):
         batch_size = self.rbln_config.model_cfg["batch_size"]
         dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
+
         self.encoder = RBLNRuntimeEncoder(
             runtime=self.model[0],
             main_input_name="input_ids",
         )
         self.decoder = RBLNRuntimeDecoder(
-            runtime=self.model[1], main_input_name="input_ids", batch_size=batch_size, dec_max_seq_len=dec_max_seq_len
+            runtime=self.model[1],
+            main_input_name="input_ids",
+            batch_size=batch_size,
+            dec_max_seq_len=dec_max_seq_len,
+            support_paged_causal_attn=self.support_paged_causal_attn,
+            use_attention_mask=self.use_attention_mask,
         )

     @classmethod
@@ -172,6 +193,16 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size

+        if cls.support_paged_causal_attn:
+            rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+            if rbln_use_attention_mask is None:
+                rbln_use_attention_mask = False
+                rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+                if rbln_npu == "RBLN-CA02":
+                    rbln_use_attention_mask = True
+        else:
+            rbln_use_attention_mask = True
+
         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
         d_kv = (
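
The added branch encodes a small decision table. A hypothetical helper restating it (resolve_use_attention_mask is an invented name; the released code inlines this logic in the config method shown above):

# --- illustrative sketch, not part of the diff ---
def resolve_use_attention_mask(support_paged_causal_attn, requested, npu_name):
    if not support_paged_causal_attn:
        return True                      # e.g. T5 keeps compiling with an explicit mask
    if requested is not None:
        return requested                 # explicit rbln_kwargs["use_attention_mask"] wins
    return npu_name == "RBLN-CA02"       # that NPU generation still needs the mask


assert resolve_use_attention_mask(False, None, "RBLN-CA02") is True
assert resolve_use_attention_mask(True, None, "RBLN-CA02") is True
assert resolve_use_attention_mask(True, None, "some-other-npu") is False  # placeholder name
assert resolve_use_attention_mask(True, False, "RBLN-CA02") is False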
@@ -232,12 +263,11 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 ],
                 "float32",
             ),
-            ("batch_position", [], "int16"),
+            ("block_tables", [1], "int16"),
         ]

         dec_input_info = [
             ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
             ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
             (
                 "cache_position",
@@ -275,6 +305,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 for i in range(n_layer * 2)
             ]
         )
+
+        if cls.support_paged_causal_attn:
+            dec_input_info.insert(3, ("block_tables", [rbln_batch_size, 1], "int16"))
+        if rbln_use_attention_mask:
+            dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+
         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

@@ -290,6 +326,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 "dec_max_seq_len": rbln_dec_max_seq_len,
                 "batch_size": rbln_batch_size,
                 "pad_token_id": rbln_pad_token_id,
+                "use_attention_mask": rbln_use_attention_mask,
             }
         )

@@ -400,9 +437,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         encoder_kwargs["output_attentions"] = False

         for b in range(batch_size):
-            batch_position = torch.tensor(b, dtype=torch.int16)
+            block_tables = torch.tensor([b], dtype=torch.int16)
             encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
             encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
-            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_position=batch_position)
+            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)

         return model_kwargs
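
The prefill loop now hands the encoder a one-element block table instead of a 0-dim batch_position, matching the new ("block_tables", [1], "int16") entry in enc_input_info; a small sketch of the shape difference (not part of the diff):

# --- illustrative sketch, not part of the diff ---
import torch

b = 1  # batch row currently being prefilled (assumed)
batch_position = torch.tensor(b, dtype=torch.int16)   # 0.7.2rc1: scalar, shape []
block_tables = torch.tensor([b], dtype=torch.int16)   # 0.7.3: shape [1]

print(batch_position.shape, block_tables.shape)  # torch.Size([]) torch.Size([1])
# Seq2SeqEncoderWrapper.forward (renamed parameter b_idx) indexes b_idx[0] before calling
# rbln_cache_update, so the cache-update op still receives a scalar index.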

optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 from torch import nn
 from transformers.utils import logging

-from ....ops import register_rbln_custom_attention, register_rbln_custom_cache_update
+from ....ops import (
+    register_rbln_custom_cache_update,
+    register_rbln_custom_paged_attention,
+    register_rbln_custom_paged_causal_attention,
+)


 logger = logging.get_logger(__name__)
@@ -87,7 +91,7 @@ class Seq2SeqEncoderWrapper(nn.Module):
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         cross_key_values: torch.Tensor,
-        batch_position: torch.Tensor,
+        b_idx: torch.Tensor,
     ) -> Tuple[torch.Tensor]:
         # 1. get encoder last_hidden_states
         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
@@ -110,11 +114,9 @@ class Seq2SeqEncoderWrapper(nn.Module):

         # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
         batch_axis = torch.tensor(1, dtype=torch.int16)
-        cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(
-            cross_key_values, cross_kv, batch_position, batch_axis
-        )
+        enc_out = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, b_idx[0], batch_axis)

-        return cross_key_values
+        return enc_out


 class Seq2SeqDecoderWrapper(nn.Module):
@@ -131,9 +133,10 @@ class Seq2SeqDecoderWrapper(nn.Module):
         **kwargs: Additional arguments for decoder configuration.
     """

-    def __init__(self, model: nn.Module, **kwargs):
+    def __init__(self, model: nn.Module, use_attention_mask: bool = True, **kwargs):
         super().__init__()
         self.config = model.config
+        self.use_attention_mask = use_attention_mask
         self.__post_init__(model, **kwargs)

     def __post_init__(self, model: nn.Module, **kwargs):
@@ -143,7 +146,11 @@ class Seq2SeqDecoderWrapper(nn.Module):
         It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
         by subclasses to modify or add custom attributes as necessary.
         """
-        register_rbln_custom_attention()
+        if self.use_attention_mask:
+            register_rbln_custom_paged_attention()
+        else:
+            register_rbln_custom_paged_causal_attention()
+
         self.num_layers = self.config.decoder_layers
         self.conditional_generation = self.convert_to_rbln_conditional_generation(model)

@@ -160,13 +167,23 @@ class Seq2SeqDecoderWrapper(nn.Module):

     def forward(
         self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-        *self_kv_cache: torch.Tensor,
+        *args,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+        if self.use_attention_mask:
+            (
+                input_ids,
+                attention_mask,
+                encoder_attention_mask,
+                cache_position,
+                block_tables,
+                cross_kv_cache,
+                *self_kv_cache,
+            ) = args
+
+        else:
+            attention_mask = None
+            (input_ids, encoder_attention_mask, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
+
         self_past_key_values = ()
         cross_past_key_values = ()
         for i in range(0, self.num_layers * 2, 2):
@@ -174,18 +191,17 @@ class Seq2SeqDecoderWrapper(nn.Module):
             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)

         # decode
-        lm_logits, self_present_key_values = self.conditional_generation(
+        lm_logits = self.conditional_generation(
             input_ids=input_ids,
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
             self_past_key_values=self_past_key_values,
             cross_past_key_values=cross_past_key_values,
             cache_position=cache_position,
+            block_tables=block_tables,
         )

-        outputs = (lm_logits,) + self_present_key_values
-
-        return outputs
+        return lm_logits


 class Seq2SeqForConditionalGeneration(nn.Module):
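
The *args unpacking above is the runtime mirror of the dec_input_info inserts in modeling_seq2seq.py; a quick sketch of the two flattened orders the compiled decoder can be built with (names only, not part of the diff):

# --- illustrative sketch, not part of the diff ---
# "kv..." stands for the interleaved cross/self KV-cache tensors that follow.
dec_inputs = ["input_ids", "encoder_attention_mask", "cache_position", "kv..."]

dec_inputs.insert(3, "block_tables")        # added when support_paged_causal_attn is set
masked = list(dec_inputs)
masked.insert(1, "attention_mask")          # added when use_attention_mask resolves to True

print(masked)      # ['input_ids', 'attention_mask', 'encoder_attention_mask', 'cache_position', 'block_tables', 'kv...']
print(dec_inputs)  # ['input_ids', 'encoder_attention_mask', 'cache_position', 'block_tables', 'kv...']
# These match the two tuple patterns Seq2SeqDecoderWrapper.forward unpacks from *args.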
@@ -228,14 +244,16 @@ class Seq2SeqForConditionalGeneration(nn.Module):
         self_past_key_values,
         cross_past_key_values,
         cache_position,
+        block_tables: Optional[torch.Tensor] = None,
     ):
-        hidden_states, self_present_key_values = self.decoder(
+        hidden_states = self.decoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
             self_past_key_values=self_past_key_values,
             cross_past_key_values=cross_past_key_values,
             cache_position=cache_position,
+            block_tables=block_tables,
         )

         if self.has_rescaling and self.config.tie_word_embeddings:
@@ -243,7 +261,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):

         lm_logits = self.lm_head(hidden_states)

-        return lm_logits, self_present_key_values
+        return lm_logits


 class Seq2SeqDecoder(torch.nn.Module):
@@ -292,6 +310,7 @@ class Seq2SeqDecoder(torch.nn.Module):
         self_past_key_values: torch.Tensor,
         cross_past_key_values: torch.Tensor,
         cache_position: torch.Tensor,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         # embedding
         hidden_states = self.get_embedding()(input_ids)
@@ -303,24 +322,23 @@ class Seq2SeqDecoder(torch.nn.Module):
         hidden_states = self.apply_position_embedding(hidden_states, cache_position)

         # iterate decoder_layer
-        self_present_key_values = ()
         for decoder_layer, self_past_key_value, cross_past_key_value in zip(
             self.layers, self_past_key_values, cross_past_key_values
         ):
-            hidden_states, self_present_key_value = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
                 encoder_attention_mask=encoder_attention_mask,
                 self_past_key_value=self_past_key_value,
                 cross_past_key_value=cross_past_key_value,
                 cache_position=cache_position,
+                block_tables=block_tables,
             )
-            self_present_key_values += self_present_key_value

         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)

-        return hidden_states, self_present_key_values
+        return hidden_states


 class Seq2SeqDecoderLayer(torch.nn.Module):
@@ -373,17 +391,19 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
         self_past_key_value: Tuple[torch.Tensor],
         cross_past_key_value: Tuple[torch.Tensor],
         cache_position: torch.Tensor,
+        block_tables: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         dummy_encoder_hidden_states = torch.zeros(1, encoder_attention_mask.shape[-1])

         # Self Attention Block
         residual = hidden_states
         hidden_states = self.pre_self_attn_layer_norm(hidden_states)
-        hidden_states, self_attn_past_key_value = self.self_attn(
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_past_key_value,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            block_tables=block_tables,
         )
         hidden_states = residual + hidden_states
         hidden_states = self.post_self_attn_layer_norm(hidden_states)
@@ -403,14 +423,14 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
         # Feed-Forward Block
         hidden_states = self.ff_layer(hidden_states)

-        return hidden_states, self_attn_past_key_value
+        return hidden_states


 class Seq2SeqSelfAttention(nn.Module):
-    def __init__(self, attn):
+    def __init__(self, attn, **kwargs):
         super().__init__()
         self._original_mod = attn
-        self.__post_init__()
+        self.__post_init__(**kwargs)

     def __post_init__(self, **kwargs):
         """
@@ -442,6 +462,7 @@ class Seq2SeqSelfAttention(nn.Module):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        block_tables: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()

@@ -450,23 +471,26 @@ class Seq2SeqSelfAttention(nn.Module):
         key_states = self._shape(key_states, -1, bsz)
         value_states = self._shape(value_states, -1, bsz)

-        attn_output, key_states, value_states = self.attn_decode(
+        block_size = past_key_value[0].shape[-2]
+        args = [
             query_states,
             key_states,
             value_states,
-            attention_mask.unsqueeze(
-                2
-            ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
             past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
             past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
             cache_position,
             torch.tensor(1.0, dtype=torch.float32),  # scale
-        )
+            block_tables,
+            block_size,
+        ]
+
+        if attention_mask is not None:
+            args.insert(3, attention_mask.unsqueeze(2))
+
+        attn_output = self.attn_decode(*args)

         attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
         attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)

         attn_output = self.out_proj(attn_output)
-        present_key_value = (key_states, value_states)

-        return attn_output, present_key_value
+        return attn_output
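
For orientation, the positional layout handed to the paged attention kernels, and how block_size falls out of the cache layout (descriptive names, assumed shapes; not part of the diff):

# --- illustrative sketch, not part of the diff ---
# masked kernel (register_rbln_custom_paged_attention):
#   [q, k, v, mask (group axis unsqueezed), k_cache, v_cache, cache_position, scale, block_tables, block_size]
# causal kernel (register_rbln_custom_paged_causal_attention):
#   [q, k, v,                               k_cache, v_cache, cache_position, scale, block_tables, block_size]
import torch

bsz, num_heads, head_dim, dec_max_seq_len = 1, 8, 64, 128  # assumed shapes
past_key = torch.zeros(bsz, num_heads, dec_max_seq_len, head_dim)
block_size = past_key.shape[-2]
print(block_size)  # 128 -> one block spans the whole decoder sequence for each batch row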

optimum/rbln/transformers/models/t5/modeling_t5.py
@@ -120,7 +120,7 @@ class RBLNT5EncoderModel(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()

         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
@@ -189,6 +189,8 @@ class RBLNT5EncoderModel(RBLNModel):


 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    support_causal_paged_attn = False
+
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         enc_max_seq_len = rbln_config.model_cfg["enc_max_seq_len"]

optimum/rbln/transformers/models/t5/t5_architecture.py
@@ -18,7 +18,7 @@ import torch
 from torch import nn
 from transformers.utils import logging

-from ....ops import register_rbln_custom_attention_add_softmax
+from ....ops import register_rbln_custom_add_softmax_attention
 from ..seq2seq.seq2seq_architecture import (
     Seq2SeqDecoder,
     Seq2SeqDecoderLayer,
@@ -55,7 +55,7 @@ class T5EncoderWrapper(Seq2SeqEncoderWrapper):

 class T5DecoderWrapper(Seq2SeqDecoderWrapper):
     def __post_init__(self, model, dec_max_seq_len: int = None):
-        register_rbln_custom_attention_add_softmax()
+        register_rbln_custom_add_softmax_attention()
         self.num_layers = self.config.num_layers
         self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)

@@ -71,6 +71,34 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):

         return new_model

+    def forward(
+        self,
+        input_ids,
+        attention_mask,
+        encoder_attention_mask,
+        cache_position,
+        cross_kv_cache,
+        *self_kv_cache,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+        self_past_key_values = ()
+        cross_past_key_values = ()
+
+        for i in range(0, self.num_layers * 2, 2):
+            self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+            cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+
+        # decode
+        lm_logits = self.conditional_generation(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_attention_mask=encoder_attention_mask,
+            self_past_key_values=self_past_key_values,
+            cross_past_key_values=cross_past_key_values,
+            cache_position=cache_position,
+        )
+
+        return lm_logits
+

 class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
     has_rescaling = True
@@ -134,7 +162,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         self.out_proj = self._original_mod.o
         self.num_heads = self._original_mod.n_heads
         self.head_dim = self._original_mod.key_value_proj_dim
-        self.attn_decode = torch.ops.rbln_custom_ops.attn_decode_add_softmax
+        self.attn_decode = torch.ops.rbln_custom_ops.add_softmax_attn_decode

     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
@@ -142,6 +170,40 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         value_states = self.v_proj(hidden_states)
         return query_states, key_states, value_states

+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Tuple[torch.Tensor],
+        attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+        bsz, tgt_len, _ = hidden_states.size()
+
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+        query_states = self._shape(query_states, tgt_len, bsz)
+        key_states = self._shape(key_states, -1, bsz)
+        value_states = self._shape(value_states, -1, bsz)
+
+        attn_output = self.attn_decode(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask.unsqueeze(
+                2
+            ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
+            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            cache_position,
+            torch.tensor(1.0, dtype=torch.float32),  # scale
+        )
+
+        attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
+
+        attn_output = self.out_proj(attn_output)
+        return attn_output
+

 class T5CrossAttention(nn.Module):
     def __init__(self, attn):
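
A closing note on why T5 overrides forward at both the wrapper and the self-attention level (sketch in comments, not part of the diff): the renamed add-softmax kernel keeps the pre-paging call shape.

# --- illustrative sketch, not part of the diff ---
# T5 self-attention (above) still calls the non-paged kernel, mask required, no block arguments:
#   torch.ops.rbln_custom_ops.add_softmax_attn_decode(
#       q, k, v, mask, k_cache, v_cache, cache_position, scale)
# whereas the base Seq2SeqSelfAttention now appends block_tables and block_size and may omit
# the mask entirely, depending on which paged kernel was registered.
# Accordingly, T5DecoderWrapper.forward keeps the 0.7.2-style flat input order
# (input_ids, attention_mask, encoder_attention_mask, cache_position, cross_kv, *self_kv)
# with no block_tables slot.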