optimum-rbln 0.7.2rc2__py3-none-any.whl → 0.7.3a1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry and is provided for informational purposes only.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/modeling_diffusers.py +4 -6
- optimum/rbln/modeling.py +1 -1
- optimum/rbln/modeling_base.py +15 -3
- optimum/rbln/ops/__init__.py +6 -2
- optimum/rbln/ops/attn.py +95 -7
- optimum/rbln/ops/flash_attn.py +43 -6
- optimum/rbln/transformers/modeling_generic.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -1
- optimum/rbln/transformers/models/bart/modeling_bart.py +1 -1
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +186 -78
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +55 -17
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -3
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +3 -3
- optimum/rbln/transformers/models/midm/midm_architecture.py +3 -3
- optimum/rbln/transformers/models/phi/phi_architecture.py +2 -2
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/t5/modeling_t5.py +1 -1
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -14
- optimum/rbln/utils/import_utils.py +7 -0
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3a1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3a1.dist-info}/RECORD +26 -26
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3a1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -19,7 +19,12 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
 
-from ....ops import
+from ....ops import (
+    register_rbln_custom_causal_masked_attention,
+    register_rbln_custom_flash_causal_masked_attention,
+    register_rbln_custom_flash_masked_attention,
+    register_rbln_custom_masked_attention,
+)
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 
@@ -128,6 +133,7 @@ class DecoderOnlyWrapper(nn.Module):
         max_seq_len: int,
         use_rotary_emb: bool,
         attn_impl: str,
+        use_attention_mask: bool,
         kvcache_partition_len: Optional[int] = None,
     ):
         super().__init__()
@@ -139,12 +145,19 @@ class DecoderOnlyWrapper(nn.Module):
             self.rotary_emb = None
 
         self.attn_impl = attn_impl
+        self.use_attention_mask = use_attention_mask
         if self.attn_impl == "flash_attn":
             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-
+            if self.use_attention_mask:
+                register_rbln_custom_flash_masked_attention()
+            else:
+                register_rbln_custom_flash_causal_masked_attention()
         elif self.attn_impl == "eager":
             self.kvcache_partition_len = None
-
+            if self.use_attention_mask:
+                register_rbln_custom_masked_attention()
+            else:
+                register_rbln_custom_causal_masked_attention()
         else:
             raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
 
@@ -154,7 +167,7 @@ class DecoderOnlyWrapper(nn.Module):
                 f" or equal to max_seq_len({max_seq_len})!"
             )
 
-        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm)
+        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm, max_seq_len)
 
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
         self._phase = "prefill"
@@ -162,21 +175,25 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
 
-    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel):
+    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+                new_self_attn = DecoderOnlyAttention(layer.self_attn, self.use_attention_mask)
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn,
+                    layer.self_attn,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    use_attention_mask=self.use_attention_mask,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
 
             new_layer = DecoderOnlyLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = DecoderOnlyModel(
+        new_model = DecoderOnlyModel(
+            causal_lm.model, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
@@ -191,23 +208,42 @@ class DecoderOnlyWrapper(nn.Module):
 
     def forward(self, *args):
        if self.phase == "decode":
-
-
-
-
-
-
+            if self.use_attention_mask:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    attention_mask,
+                    *past_key_values,
+                ) = args
+            else:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    *past_key_values,
+                ) = args
+                attention_mask = None
             batch_position = torch.tensor(0, dtype=torch.int16)
             query_position = None
         elif self.phase == "prefill":
-
-
-
-
-
-
-
-
+            if self.use_attention_mask:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    attention_mask,
+                    batch_position,
+                    query_position,
+                    *past_key_values,
+                ) = args
+            else:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    batch_position,
+                    query_position,
+                    *past_key_values,
+                ) = args
+                attention_mask = None
+
         else:
             raise ValueError(f"Unknown phase: {self.phase}")
 
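The compiled graph's flattened call signature therefore changes with use_attention_mask. The following is a minimal, dependency-free sketch of the same unpacking; the helper name and the sample values are ours, not part of the package:

    # Illustration of DecoderOnlyWrapper.forward's argument layout per phase/mask setting.
    def unpack_wrapper_args(phase, use_attention_mask, *args):
        attention_mask, batch_position, query_position = None, 0, None
        if phase == "decode":
            if use_attention_mask:
                inputs, cache_position, attention_mask, *past_key_values = args
            else:
                inputs, cache_position, *past_key_values = args
        elif phase == "prefill":
            if use_attention_mask:
                inputs, cache_position, attention_mask, batch_position, query_position, *past_key_values = args
            else:
                inputs, cache_position, batch_position, query_position, *past_key_values = args
        else:
            raise ValueError(f"Unknown phase: {phase}")
        return inputs, cache_position, attention_mask, batch_position, query_position, past_key_values

    # Decode without a mask: only input ids, cache position, and the KV cache are passed.
    print(unpack_wrapper_args("decode", False, "ids", "pos", "k0", "v0"))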
@@ -338,12 +374,13 @@ class DecoderOnlyModel(nn.Module):
         _phase: Current processing phase ("prefill" or "decode")
     """
 
-    def __init__(self, model, layers: List["DecoderOnlyLayer"], partition_len=None):
+    def __init__(self, model, layers: List["DecoderOnlyLayer"], partition_len=None, max_seq_len=None):
         super().__init__()
         self._original_mod = model
         self.layers = nn.ModuleList(layers)
         self._phase = "prefill"
         self.partition_len = partition_len
+        self.max_seq_len = max_seq_len
 
     @property
     def phase(self):
@@ -410,7 +447,7 @@ class DecoderOnlyModel(nn.Module):
 
         # get cos,sin vector if needed
         if rotary_emb is not None:
-            cos, sin = rotary_emb(hidden_states,
+            cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
             cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
         else:
             batch_size = inputs_embeds.shape[0]
@@ -542,7 +579,7 @@ class DecoderOnlyAttention(nn.Module):
         self_attn: Original attention module from the base model
     """
 
-    def __init__(self, self_attn):
+    def __init__(self, self_attn, use_attention_mask):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -560,6 +597,7 @@ class DecoderOnlyAttention(nn.Module):
         else:
             self.num_key_value_heads = self.num_heads
 
+        self.use_attention_mask = use_attention_mask
         self.attention = self.get_attention()
         self.__post_init__()
 
@@ -573,7 +611,7 @@ class DecoderOnlyAttention(nn.Module):
         self.attention.phase = phase
 
     def get_attention(self):
-        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads)
+        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask)
 
     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -648,12 +686,13 @@ class DecoderOnlyAttention(nn.Module):
 
 
 class AttentionOp(nn.Module):
-    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int):
+    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
+        self.use_attention_mask = use_attention_mask
 
     def forward(
         self,
@@ -686,7 +725,8 @@ class AttentionOp(nn.Module):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
-
+        if self.use_attention_mask:
+            attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
             batch_size = key_state.shape[0]
@@ -702,29 +742,52 @@ class AttentionOp(nn.Module):
             )
 
         if self.phase == "decode":
-
-
-
-
-
-
-
-
-
-
+            if self.use_attention_mask:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.masked_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                )
+            else:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.causal_masked_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                )
 
         else:
-
-
-
-
-
-
-
-
-
-
-
+            if self.use_attention_mask:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.masked_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    batch_position,
+                    seq_position,
+                    scale,
+                )
+            else:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.causal_masked_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    batch_position,
+                    seq_position,
+                    scale,
+                )
 
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
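The kernel dispatch therefore depends only on (use_attention_mask, phase). A small illustrative lookup; the table and helper are ours, while the op names are the ones registered and called above:

    # Which rbln_custom_ops kernel AttentionOp.forward ends up calling, per configuration.
    EAGER_ATTN_OPS = {
        (True, "decode"): "masked_attn_decode",
        (True, "prefill"): "masked_attn_prefill",
        (False, "decode"): "causal_masked_attn_decode",
        (False, "prefill"): "causal_masked_attn_prefill",
    }

    def select_attn_op(use_attention_mask: bool, phase: str) -> str:
        return EAGER_ATTN_OPS[(use_attention_mask, phase)]

    assert select_attn_op(False, "decode") == "causal_masked_attn_decode"

The flash variants below follow the same masked/causal split (flash_masked_* vs flash_causal_masked_*), with the KV-cache partition size passed as an extra trailing argument.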
@@ -826,12 +889,19 @@ class RotaryEmbedding(nn.Module):
 
 
 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len):
+    def __init__(self, self_attn, kvcache_partition_len, use_attention_mask):
         self.kvcache_partition_size = kvcache_partition_len
-
+        # self.use_attention_mask = use_attention_mask
+        super().__init__(self_attn=self_attn, use_attention_mask=use_attention_mask)
 
     def get_attention(self):
-        return FlashAttentionOp(
+        return FlashAttentionOp(
+            self.num_heads,
+            self.head_dim,
+            self.num_key_value_heads,
+            self.kvcache_partition_size,
+            self.use_attention_mask,
+        )
 
     def forward(
         self,
@@ -878,8 +948,20 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
 
 
 class FlashAttentionOp(AttentionOp):
-    def __init__(
-
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        num_key_value_heads: int,
+        kvcache_partition_len: int,
+        use_attention_mask: bool,
+    ):
+        super().__init__(
+            num_heads=num_heads,
+            head_dim=head_dim,
+            num_key_value_heads=num_key_value_heads,
+            use_attention_mask=use_attention_mask,
+        )
         self.kvcache_partition_size = kvcache_partition_len
 
     def forward(
@@ -897,7 +979,8 @@ class FlashAttentionOp(AttentionOp):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-
+        if self.use_attention_mask:
+            attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
             batch_size = key_state.shape[0]
@@ -913,30 +996,55 @@ class FlashAttentionOp(AttentionOp):
             )
 
         if self.phase == "decode":
-
-
-
-
-
-
-
-
-
-
-
+            if self.use_attention_mask:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_masked_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    self.kvcache_partition_size,
+                )
+            else:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_causal_masked_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    self.kvcache_partition_size,
+                )
         else:
-
-
-
-
-
-
-
-
-
-
-
-
+            if self.use_attention_mask:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_masked_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    batch_position,
+                    seq_position,
+                    scale,
+                    self.kvcache_partition_size,
+                )
+            else:
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_causal_masked_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    batch_position,
+                    seq_position,
+                    scale,
+                    self.kvcache_partition_size,
+                )
 
         # reshape for removing repeat_kv
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -50,12 +50,15 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         phase: str,
         batch_size: int,
         dec_attn_mask: torch.Tensor,
+        use_attention_mask: bool,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
 
+        self.use_attention_mask = use_attention_mask
+
         # shared tensor between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
 
@@ -110,7 +113,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
 
-        if attention_mask is None:
+        if self.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -119,10 +122,12 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                     )
                 self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
 
+            attention_mask = self.dec_attn_mask
+
         logits = super().forward(
             inputs,
-            self.dec_attn_mask if attention_mask is None else attention_mask,
             cache_position,
+            attention_mask if self.use_attention_mask else None,
         )
 
         return logits
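A standalone sketch of the decode-phase mask bookkeeping shown above; the [batch, 1, 1, max_seq_len] shape follows the attention_mask entry added to input_info further down, and the sample sizes are ours:

    import torch

    batch_size, max_seq_len = 2, 8
    # Persistent mask shared between the prefill and decode runtimes.
    dec_attn_mask = torch.zeros(batch_size, 1, 1, max_seq_len)
    cache_position = torch.tensor([[3], [5]])  # current decoding step of each sequence

    for b_idx in range(batch_size):
        decoding_step = cache_position[b_idx].item()
        # Each decode step unmasks one more KV-cache slot; slots unmasked during
        # prefill or earlier decode steps stay visible, so the mask accumulates.
        dec_attn_mask[b_idx, :, :, decoding_step] = 1

    print(dec_attn_mask[:, 0, 0])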
@@ -156,7 +161,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )
 
         # Initialize attention mask for chunked processing
-
+        if self.use_attention_mask:
+            chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
 
         # Buffer for storing output logits
         out_buffers = [
@@ -195,28 +201,41 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             input_chunk = inputs[:, step : step + self.prefill_chunk_size]
             cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
 
-
-
-
-
+            if self.use_attention_mask:
+                # Update attention mask to ensure proper causal behavior
+                if step >= self.prefill_chunk_size:
+                    chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+                chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
             # Define batch position and query position
             batch_position = torch.tensor(batch_idx, dtype=torch.int16)
             query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
 
+            if self.use_attention_mask:
+                args = (
+                    input_chunk,
+                    cache_pos_chunk,
+                    chunked_attention_mask,
+                    batch_position,
+                    query_position,
+                )
+            else:
+                args = (
+                    input_chunk,
+                    cache_pos_chunk,
+                    batch_position,
+                    query_position,
+                )
             # Forward pass for the current chunk
             logits = super().forward(
-
-                chunked_attention_mask,
-                cache_pos_chunk,
-                batch_position,
-                query_position,
+                *args,
                 out=out_buffers,
             )
 
-
-
-
+            if self.use_attention_mask:
+                # Update decoder attention mask with processed KV-cache length from prefill phase
+                self.dec_attn_mask[batch_idx].fill_(0)
+                self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
         return logits
 
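A runnable sketch of how the chunked prefill mask evolves from chunk to chunk, with toy sizes; self.causal_mask is not shown in this diff, so a lower-triangular [1, 1, chunk, chunk] block is assumed here:

    import torch

    prefill_chunk_size, max_seq_len = 4, 16
    causal_mask = torch.tril(torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size))
    chunked_attention_mask = torch.zeros(1, 1, prefill_chunk_size, max_seq_len)

    for step in range(0, 8, prefill_chunk_size):  # an 8-token prompt processed in two chunks
        if step >= prefill_chunk_size:
            # everything already written to the KV cache by earlier chunks is fully visible
            chunked_attention_mask[:, :, :, step - prefill_chunk_size : step] = 1
        # the chunk currently being processed gets the causal (lower-triangular) pattern
        chunked_attention_mask[:, :, :, step : step + prefill_chunk_size] = causal_mask

    print(chunked_attention_mask[0, 0])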
|
@@ -256,6 +275,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
256
275
|
self.batch_size = self.rbln_config.model_cfg["batch_size"]
|
257
276
|
self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
|
258
277
|
self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
|
278
|
+
self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
|
259
279
|
|
260
280
|
main_input_name = self.main_input_name
|
261
281
|
if self.rbln_config.model_cfg["use_inputs_embeds"]:
|
@@ -282,6 +302,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
282
302
|
vocab_size=self.config.vocab_size,
|
283
303
|
max_seq_len=self.max_seq_len,
|
284
304
|
prefill_chunk_size=self.prefill_chunk_size,
|
305
|
+
use_attention_mask=self.use_attention_mask,
|
285
306
|
)
|
286
307
|
self.decoder = RBLNRuntimeModel(
|
287
308
|
runtime=self.model[1],
|
@@ -290,6 +311,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
290
311
|
phase="decode",
|
291
312
|
batch_size=self.batch_size,
|
292
313
|
dec_attn_mask=dec_attn_mask,
|
314
|
+
use_attention_mask=self.use_attention_mask,
|
293
315
|
)
|
294
316
|
|
295
317
|
@classmethod
|
@@ -363,7 +385,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
363
385
|
def redirect(func):
|
364
386
|
return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
|
365
387
|
|
366
|
-
val = getattr(self.
|
388
|
+
val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
|
367
389
|
if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
|
368
390
|
return redirect(val)
|
369
391
|
return val
|
@@ -388,6 +410,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
388
410
|
wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
|
389
411
|
wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
|
390
412
|
wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
|
413
|
+
wrapper_cfg["use_attention_mask"] = rbln_config.model_cfg.get("use_attention_mask")
|
391
414
|
|
392
415
|
return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
|
393
416
|
|
@@ -448,11 +471,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
+        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
         rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
         rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
         rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
         rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)
 
+        if rbln_use_attention_mask is None:
+            rbln_use_attention_mask = False
+            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+            if rbln_npu == "RBLN-CA02":
+                rbln_use_attention_mask = True
+
         if rbln_prefill_chunk_size is None:
             rbln_prefill_chunk_size = 128
         elif rbln_prefill_chunk_size % 64 != 0 or rbln_prefill_chunk_size == 0:
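So when the caller does not set it, the masked attention path is enabled only on RBLN-CA02 NPUs; every other target compiles the mask-free causal kernels. A hedged usage sketch, assuming the usual rbln_*-prefixed export kwargs of optimum-rbln (model id and sizes are illustrative, not taken from this diff):

    from optimum.rbln import RBLNLlamaForCausalLM

    # Force the explicit-attention-mask graphs regardless of the detected NPU;
    # omit rbln_use_attention_mask to fall back to the autodetection above.
    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        export=True,
        rbln_batch_size=1,
        rbln_max_seq_len=4096,
        rbln_use_attention_mask=True,
    )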
@@ -495,13 +525,20 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         input_info = [
             main_input,
-            ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
             (
                 "cache_position",
                 [batch_size, query_length],
                 "int32",
             ),
         ]
+
+        if rbln_use_attention_mask:
+            input_info.extend(
+                [
+                    ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
+                ]
+            )
+
         if query_length > 1:
             input_info.extend(
                 [
@@ -555,6 +592,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "max_seq_len": rbln_max_seq_len,
                 "batch_size": rbln_batch_size,
                 "prefill_chunk_size": rbln_prefill_chunk_size,
+                "use_attention_mask": rbln_use_attention_mask,
                 "use_inputs_embeds": rbln_use_inputs_embeds,
                 "kvcache_partition_len": rbln_kvcache_partition_len,
                 "attn_impl": rbln_attn_impl,
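Net effect on the compiled I/O: attention_mask moves from a mandatory second input to an optional one appended after cache_position. A small illustrative helper; the name and dtype of the main input entry are our assumption, not part of this diff:

    def build_input_info(batch_size, query_length, max_seq_len, use_attention_mask):
        # Mirrors the construction above; only the mask entry is conditional.
        info = [
            ("input_ids", [batch_size, query_length], "int64"),  # stand-in for main_input
            ("cache_position", [batch_size, query_length], "int32"),
        ]
        if use_attention_mask:
            info.append(("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"))
        return info

    print(build_input_info(1, 128, 4096, use_attention_mask=False))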
optimum/rbln/transformers/models/exaone/exaone_architecture.py

@@ -36,11 +36,11 @@ logger = logging.get_logger(__name__)
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""
 
-    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM", max_seq_len: int):
         new_layers = []
         for layer in causal_lm.transformer.h:
             if self.attn_impl == "eager":
-                new_self_attn = ExaoneAttention(layer.attn.attention)
+                new_self_attn = ExaoneAttention(layer.attn.attention, self.use_attention_mask)
             elif self.attn_impl == "flash_attn":
                 new_self_attn = ExaoneFlashAttention(
                     layer.attn.attention, kvcache_partition_len=self.kvcache_partition_len
@@ -50,7 +50,9 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
 
             new_layer = ExaoneLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = ExaoneModel(
+        new_model = ExaoneModel(
+            causal_lm.transformer, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
optimum/rbln/transformers/models/gemma/gemma_architecture.py

@@ -29,11 +29,11 @@ if TYPE_CHECKING:
 
 
 class GemmaWrapper(DecoderOnlyWrapper):
-    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM", max_seq_len: int):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+                new_self_attn = DecoderOnlyAttention(layer.self_attn, self.use_attention_mask)
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
                     layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
@@ -42,7 +42,9 @@ class GemmaWrapper(DecoderOnlyWrapper):
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
             new_layer = DecoderOnlyLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = GemmaModel(
+        new_model = GemmaModel(
+            causal_lm.model, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 