optimum-rbln 0.7.2rc1__py3-none-any.whl → 0.7.3a0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. The information is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
Files changed (26)
  1. optimum/rbln/__version__.py +9 -4
  2. optimum/rbln/diffusers/modeling_diffusers.py +18 -12
  3. optimum/rbln/modeling.py +1 -1
  4. optimum/rbln/modeling_base.py +15 -3
  5. optimum/rbln/ops/__init__.py +6 -2
  6. optimum/rbln/ops/attn.py +95 -7
  7. optimum/rbln/ops/flash_attn.py +43 -6
  8. optimum/rbln/transformers/modeling_generic.py +3 -3
  9. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -1
  10. optimum/rbln/transformers/models/bart/modeling_bart.py +1 -1
  11. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
  12. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +186 -78
  13. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +55 -17
  14. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -3
  15. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -3
  16. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +3 -3
  17. optimum/rbln/transformers/models/midm/midm_architecture.py +3 -3
  18. optimum/rbln/transformers/models/phi/phi_architecture.py +2 -2
  19. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  20. optimum/rbln/transformers/models/t5/modeling_t5.py +1 -1
  21. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -1
  22. optimum/rbln/utils/import_utils.py +7 -0
  23. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3a0.dist-info}/METADATA +1 -1
  24. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3a0.dist-info}/RECORD +26 -26
  25. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3a0.dist-info}/WHEEL +0 -0
  26. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3a0.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -19,7 +19,12 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
 
-from ....ops import register_rbln_custom_attention, register_rbln_custom_flash_attention
+from ....ops import (
+    register_rbln_custom_causal_masked_attention,
+    register_rbln_custom_flash_causal_masked_attention,
+    register_rbln_custom_flash_masked_attention,
+    register_rbln_custom_masked_attention,
+)
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 
@@ -128,6 +133,7 @@ class DecoderOnlyWrapper(nn.Module):
         max_seq_len: int,
         use_rotary_emb: bool,
         attn_impl: str,
+        use_attention_mask: bool,
         kvcache_partition_len: Optional[int] = None,
     ):
         super().__init__()
@@ -139,12 +145,19 @@ class DecoderOnlyWrapper(nn.Module):
         self.rotary_emb = None
 
         self.attn_impl = attn_impl
+        self.use_attention_mask = use_attention_mask
         if self.attn_impl == "flash_attn":
             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-            register_rbln_custom_flash_attention()
+            if self.use_attention_mask:
+                register_rbln_custom_flash_masked_attention()
+            else:
+                register_rbln_custom_flash_causal_masked_attention()
         elif self.attn_impl == "eager":
             self.kvcache_partition_len = None
-            register_rbln_custom_attention()
+            if self.use_attention_mask:
+                register_rbln_custom_masked_attention()
+            else:
+                register_rbln_custom_causal_masked_attention()
         else:
             raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
 
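Taken together, the two hunks above make the wrapper register one of four custom attention ops instead of two, keyed on both attn_impl and the new use_attention_mask flag. A minimal standalone sketch of that selection (illustrative only, not code from the package):

    # Which registration helper __init__ calls, keyed by (attn_impl, use_attention_mask);
    # this table mirrors the branches in the hunk above.
    REGISTRATION = {
        ("eager", True): "register_rbln_custom_masked_attention",
        ("eager", False): "register_rbln_custom_causal_masked_attention",
        ("flash_attn", True): "register_rbln_custom_flash_masked_attention",
        ("flash_attn", False): "register_rbln_custom_flash_causal_masked_attention",
    }

    def pick_registration(attn_impl: str, use_attention_mask: bool) -> str:
        # Unknown attn_impl values fail loudly, as in the wrapper itself.
        try:
            return REGISTRATION[(attn_impl, use_attention_mask)]
        except KeyError:
            raise ValueError(f"Unknown attn_impl : {attn_impl}")
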
@@ -154,7 +167,7 @@ class DecoderOnlyWrapper(nn.Module):
                 f" or equal to max_seq_len({max_seq_len})!"
             )
 
-        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm)
+        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm, max_seq_len)
 
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
         self._phase = "prefill"
@@ -162,21 +175,25 @@
     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
 
-    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel):
+    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+                new_self_attn = DecoderOnlyAttention(layer.self_attn, self.use_attention_mask)
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
+                    layer.self_attn,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    use_attention_mask=self.use_attention_mask,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
 
             new_layer = DecoderOnlyLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = DecoderOnlyModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
+        new_model = DecoderOnlyModel(
+            causal_lm.model, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
@@ -191,23 +208,42 @@ class DecoderOnlyWrapper(nn.Module):
 
     def forward(self, *args):
         if self.phase == "decode":
-            (
-                input_ids_or_inputs_embeds,
-                attention_mask,
-                cache_position,
-                *past_key_values,
-            ) = args
+            if self.use_attention_mask:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    attention_mask,
+                    *past_key_values,
+                ) = args
+            else:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    *past_key_values,
+                ) = args
+                attention_mask = None
             batch_position = torch.tensor(0, dtype=torch.int16)
             query_position = None
         elif self.phase == "prefill":
-            (
-                input_ids_or_inputs_embeds,
-                attention_mask,
-                cache_position,
-                batch_position,
-                query_position,
-                *past_key_values,
-            ) = args
+            if self.use_attention_mask:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    attention_mask,
+                    batch_position,
+                    query_position,
+                    *past_key_values,
+                ) = args
+            else:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    batch_position,
+                    query_position,
+                    *past_key_values,
+                ) = args
+                attention_mask = None
+
         else:
             raise ValueError(f"Unknown phase: {self.phase}")
 
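The flat signature of the traced wrapper changes with this hunk: cache_position now follows the input ids/embeddings directly, and attention_mask is only present when use_attention_mask is set, with the KV-cache tensors always trailing the listed inputs. A rough summary of the four layouts (names taken from the hunk above, ordering only):

    # Positional layouts for DecoderOnlyWrapper.forward(*args); *past_key_values follows each.
    DECODE_WITH_MASK = ("input_ids_or_inputs_embeds", "cache_position", "attention_mask")
    DECODE_WITHOUT_MASK = ("input_ids_or_inputs_embeds", "cache_position")
    PREFILL_WITH_MASK = (
        "input_ids_or_inputs_embeds",
        "cache_position",
        "attention_mask",
        "batch_position",
        "query_position",
    )
    PREFILL_WITHOUT_MASK = (
        "input_ids_or_inputs_embeds",
        "cache_position",
        "batch_position",
        "query_position",
    )
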
@@ -338,12 +374,13 @@ class DecoderOnlyModel(nn.Module):
         _phase: Current processing phase ("prefill" or "decode")
     """
 
-    def __init__(self, model, layers: List["DecoderOnlyLayer"], partition_len=None):
+    def __init__(self, model, layers: List["DecoderOnlyLayer"], partition_len=None, max_seq_len=None):
         super().__init__()
         self._original_mod = model
         self.layers = nn.ModuleList(layers)
         self._phase = "prefill"
         self.partition_len = partition_len
+        self.max_seq_len = max_seq_len
 
     @property
     def phase(self):
@@ -410,7 +447,7 @@ class DecoderOnlyModel(nn.Module):
 
         # get cos,sin vector if needed
         if rotary_emb is not None:
-            cos, sin = rotary_emb(hidden_states, attention_mask.shape[-1])  # dtype carrier, max_seq_len
+            cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
             cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
         else:
             batch_size = inputs_embeds.shape[0]
@@ -542,7 +579,7 @@ class DecoderOnlyAttention(nn.Module):
         self_attn: Original attention module from the base model
     """
 
-    def __init__(self, self_attn):
+    def __init__(self, self_attn, use_attention_mask):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -560,6 +597,7 @@ class DecoderOnlyAttention(nn.Module):
         else:
             self.num_key_value_heads = self.num_heads
 
+        self.use_attention_mask = use_attention_mask
         self.attention = self.get_attention()
         self.__post_init__()
 
@@ -573,7 +611,7 @@ class DecoderOnlyAttention(nn.Module):
         self.attention.phase = phase
 
     def get_attention(self):
-        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads)
+        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask)
 
     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -648,12 +686,13 @@ class DecoderOnlyAttention(nn.Module):
 
 
 class AttentionOp(nn.Module):
-    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int):
+    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
+        self.use_attention_mask = use_attention_mask
 
     def forward(
         self,
@@ -686,7 +725,8 @@ class AttentionOp(nn.Module):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
-        attn_mask = attn_mask.unsqueeze(2)
+        if self.use_attention_mask:
+            attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
             batch_size = key_state.shape[0]
@@ -702,29 +742,52 @@ class AttentionOp(nn.Module):
                 )
 
         if self.phase == "decode":
-            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.attn_decode(
-                query_state,
-                key_state,
-                value_state,
-                attn_mask,
-                past_key_state.unsqueeze(2),
-                past_value_state.unsqueeze(2),
-                seq_position,
-                scale,
-            )
+            if self.use_attention_mask:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.masked_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                )
+            else:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.causal_masked_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                )
 
         else:
-            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.attn_prefill(
-                query_state,
-                key_state,
-                value_state,
-                attn_mask,
-                past_key_state.unsqueeze(2),
-                past_value_state.unsqueeze(2),
-                batch_position,
-                seq_position,
-                scale,
-            )
+            if self.use_attention_mask:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.masked_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    batch_position,
+                    seq_position,
+                    scale,
+                )
+            else:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.causal_masked_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    batch_position,
+                    seq_position,
+                    scale,
+                )
 
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
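
So the eager path now dispatches to the masked_attn_* ops when a mask is supplied and to the causal_masked_attn_* ops otherwise, with the causal variants simply omitting the attn_mask argument. A hedged sketch of that naming rule (the helper is hypothetical; the real code calls torch.ops.rbln_custom_ops.<name> directly, as shown above):

    def eager_attn_op_name(phase: str, use_attention_mask: bool) -> str:
        # e.g. ("decode", False) -> "causal_masked_attn_decode"
        prefix = "masked_attn" if use_attention_mask else "causal_masked_attn"
        suffix = "decode" if phase == "decode" else "prefill"
        return f"{prefix}_{suffix}"
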
@@ -826,12 +889,19 @@ class RotaryEmbedding(nn.Module):
 
 
 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len):
+    def __init__(self, self_attn, kvcache_partition_len, use_attention_mask):
         self.kvcache_partition_size = kvcache_partition_len
-        super().__init__(self_attn=self_attn)
+        # self.use_attention_mask = use_attention_mask
+        super().__init__(self_attn=self_attn, use_attention_mask=use_attention_mask)
 
     def get_attention(self):
-        return FlashAttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.kvcache_partition_size)
+        return FlashAttentionOp(
+            self.num_heads,
+            self.head_dim,
+            self.num_key_value_heads,
+            self.kvcache_partition_size,
+            self.use_attention_mask,
+        )
 
     def forward(
         self,
@@ -878,8 +948,20 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
 
 
 class FlashAttentionOp(AttentionOp):
-    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, kvcache_partition_len: int):
-        super().__init__(num_heads=num_heads, head_dim=head_dim, num_key_value_heads=num_key_value_heads)
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        num_key_value_heads: int,
+        kvcache_partition_len: int,
+        use_attention_mask: bool,
+    ):
+        super().__init__(
+            num_heads=num_heads,
+            head_dim=head_dim,
+            num_key_value_heads=num_key_value_heads,
+            use_attention_mask=use_attention_mask,
+        )
         self.kvcache_partition_size = kvcache_partition_len
 
     def forward(
@@ -897,7 +979,8 @@ class FlashAttentionOp(AttentionOp):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-        attn_mask = attn_mask.unsqueeze(2)
+        if self.use_attention_mask:
+            attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
             batch_size = key_state.shape[0]
@@ -913,30 +996,55 @@ class FlashAttentionOp(AttentionOp):
                 )
 
         if self.phase == "decode":
-            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
-                query_state,
-                key_state,
-                value_state,
-                attn_mask,
-                past_key_state.unsqueeze(2),
-                past_value_state.unsqueeze(2),
-                seq_position,
-                scale,
-                self.kvcache_partition_size,
-            )
+            if self.use_attention_mask:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_masked_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    self.kvcache_partition_size,
+                )
+            else:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_causal_masked_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    self.kvcache_partition_size,
+                )
         else:
-            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_prefill(
-                query_state,
-                key_state,
-                value_state,
-                attn_mask,
-                past_key_state.unsqueeze(2),
-                past_value_state.unsqueeze(2),
-                batch_position,
-                seq_position,
-                scale,
-                self.kvcache_partition_size,
-            )
+            if self.use_attention_mask:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_masked_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    batch_position,
+                    seq_position,
+                    scale,
+                    self.kvcache_partition_size,
+                )
+            else:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_causal_masked_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    batch_position,
+                    seq_position,
+                    scale,
+                    self.kvcache_partition_size,
+                )
 
         # reshape for removing repeat_kv
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)

optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -50,12 +50,15 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         phase: str,
         batch_size: int,
         dec_attn_mask: torch.Tensor,
+        use_attention_mask: bool,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
 
+        self.use_attention_mask = use_attention_mask
+
         # shared tensor between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
 
@@ -110,7 +113,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
 
-        if attention_mask is None:
+        if self.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -119,10 +122,12 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                     )
                 self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
 
+            attention_mask = self.dec_attn_mask
+
         logits = super().forward(
             inputs,
-            self.dec_attn_mask if attention_mask is None else attention_mask,
             cache_position,
+            attention_mask if self.use_attention_mask else None,
         )
 
         return logits
@@ -156,7 +161,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )
 
         # Initialize attention mask for chunked processing
-        chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+        if self.use_attention_mask:
+            chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
 
         # Buffer for storing output logits
         out_buffers = [
@@ -195,28 +201,41 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             input_chunk = inputs[:, step : step + self.prefill_chunk_size]
            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
 
-            # Update attention mask to ensure proper causal behavior
-            if step >= self.prefill_chunk_size:
-                chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-            chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+            if self.use_attention_mask:
+                # Update attention mask to ensure proper causal behavior
+                if step >= self.prefill_chunk_size:
+                    chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+                chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
             # Define batch position and query position
             batch_position = torch.tensor(batch_idx, dtype=torch.int16)
             query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
 
+            if self.use_attention_mask:
+                args = (
+                    input_chunk,
+                    cache_pos_chunk,
+                    chunked_attention_mask,
+                    batch_position,
+                    query_position,
+                )
+            else:
+                args = (
+                    input_chunk,
+                    cache_pos_chunk,
+                    batch_position,
+                    query_position,
+                )
             # Forward pass for the current chunk
             logits = super().forward(
-                input_chunk,
-                chunked_attention_mask,
-                cache_pos_chunk,
-                batch_position,
-                query_position,
+                *args,
                 out=out_buffers,
             )
 
-        # Update decoder attention mask with processed KV-cache length from prefill phase
-        self.dec_attn_mask[batch_idx].fill_(0)
-        self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+        if self.use_attention_mask:
+            # Update decoder attention mask with processed KV-cache length from prefill phase
+            self.dec_attn_mask[batch_idx].fill_(0)
+            self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
         return logits
 
@@ -256,6 +275,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
         self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
         self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
+        self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
 
         main_input_name = self.main_input_name
         if self.rbln_config.model_cfg["use_inputs_embeds"]:
@@ -282,6 +302,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             vocab_size=self.config.vocab_size,
             max_seq_len=self.max_seq_len,
             prefill_chunk_size=self.prefill_chunk_size,
+            use_attention_mask=self.use_attention_mask,
         )
         self.decoder = RBLNRuntimeModel(
             runtime=self.model[1],
@@ -290,6 +311,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             phase="decode",
             batch_size=self.batch_size,
             dec_attn_mask=dec_attn_mask,
+            use_attention_mask=self.use_attention_mask,
         )
 
     @classmethod
@@ -363,7 +385,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         def redirect(func):
             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
 
-        val = getattr(self.hf_class, __name, None) or getattr(PreTrainedModel, __name)
+        val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
             return redirect(val)
         return val
@@ -388,6 +410,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
         wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
         wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
+        wrapper_cfg["use_attention_mask"] = rbln_config.model_cfg.get("use_attention_mask")
 
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
@@ -448,11 +471,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
+        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
         rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
         rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
         rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
         rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)
 
+        if rbln_use_attention_mask is None:
+            rbln_use_attention_mask = False
+            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+            if rbln_npu == "RBLN-CA02":
+                rbln_use_attention_mask = True
+
         if rbln_prefill_chunk_size is None:
             rbln_prefill_chunk_size = 128
         elif rbln_prefill_chunk_size % 64 != 0 or rbln_prefill_chunk_size == 0:
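
When use_attention_mask is not passed explicitly, it now defaults to False and is flipped to True only when the target NPU reports itself as RBLN-CA02. A minimal sketch of that rule (assuming rebel.get_npu_name() returns the device name string, as the hunk above implies):

    def default_use_attention_mask(explicit, npu_name):
        # An explicit user setting wins; otherwise only RBLN-CA02 keeps mask-based attention.
        if explicit is not None:
            return explicit
        return npu_name == "RBLN-CA02"
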
@@ -495,13 +525,20 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
             input_info = [
                 main_input,
-                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
                 (
                     "cache_position",
                     [batch_size, query_length],
                     "int32",
                 ),
             ]
+
+            if rbln_use_attention_mask:
+                input_info.extend(
+                    [
+                        ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
+                    ]
+                )
+
             if query_length > 1:
                 input_info.extend(
                     [
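
As a consequence, the compiled graphs declare an attention_mask input of shape [batch_size, 1, query_length, max_seq_len] only when rbln_use_attention_mask is set; otherwise that input disappears from the signature. An illustrative sketch of the leading entries (the main-input name and dtype here are placeholders, not taken from the diff):

    def build_input_info_prefix(batch_size, query_length, max_seq_len, use_attention_mask):
        info = [
            ("input_ids", [batch_size, query_length], "int64"),  # placeholder main input
            ("cache_position", [batch_size, query_length], "int32"),
        ]
        if use_attention_mask:
            info.append(("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"))
        return info
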
@@ -555,6 +592,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "max_seq_len": rbln_max_seq_len,
                 "batch_size": rbln_batch_size,
                 "prefill_chunk_size": rbln_prefill_chunk_size,
+                "use_attention_mask": rbln_use_attention_mask,
                 "use_inputs_embeds": rbln_use_inputs_embeds,
                 "kvcache_partition_len": rbln_kvcache_partition_len,
                 "attn_impl": rbln_attn_impl,

optimum/rbln/transformers/models/exaone/exaone_architecture.py

@@ -36,11 +36,11 @@ logger = logging.get_logger(__name__)
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""
 
-    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM", max_seq_len: int):
         new_layers = []
         for layer in causal_lm.transformer.h:
             if self.attn_impl == "eager":
-                new_self_attn = ExaoneAttention(layer.attn.attention)
+                new_self_attn = ExaoneAttention(layer.attn.attention, self.use_attention_mask)
             elif self.attn_impl == "flash_attn":
                 new_self_attn = ExaoneFlashAttention(
                     layer.attn.attention, kvcache_partition_len=self.kvcache_partition_len
@@ -50,7 +50,9 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
 
             new_layer = ExaoneLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = ExaoneModel(causal_lm.transformer, new_layers, partition_len=self.kvcache_partition_len)
+        new_model = ExaoneModel(
+            causal_lm.transformer, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 

optimum/rbln/transformers/models/gemma/gemma_architecture.py

@@ -29,11 +29,11 @@ if TYPE_CHECKING:
 
 
 class GemmaWrapper(DecoderOnlyWrapper):
-    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM", max_seq_len: int):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+                new_self_attn = DecoderOnlyAttention(layer.self_attn, self.use_attention_mask)
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
                     layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
@@ -42,7 +42,9 @@ class GemmaWrapper(DecoderOnlyWrapper):
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
             new_layer = DecoderOnlyLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = GemmaModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
+        new_model = GemmaModel(
+            causal_lm.model, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 