optimum-rbln 0.8.1rc1__py3-none-any.whl → 0.8.2a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +210 -257
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +2 -40
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +18 -22
- optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
- optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
- optimum/rbln/transformers/models/phi/phi_architecture.py +14 -20
- {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2a1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2a1.dist-info}/RECORD +19 -19
- {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2a1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -149,6 +149,8 @@ class DecoderOnlyWrapper(nn.Module):
             This is only relevant if `attn_impl` is set to "flash_attn".
     """
 
+    _use_learned_pos_emb = False
+
     def __init__(
         self,
         causal_lm: PreTrainedModel,
@@ -159,7 +161,6 @@ class DecoderOnlyWrapper(nn.Module):
         use_inputs_embeds: bool,
         use_attention_mask: bool,
         use_position_ids: bool,
-        use_learned_pos_emb: Optional[bool] = None,
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
         sliding_window: Optional[int] = None,
@@ -182,7 +183,6 @@ class DecoderOnlyWrapper(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
         self.use_inputs_embeds = use_inputs_embeds
-        self.use_learned_pos_emb = use_learned_pos_emb
         self.sliding_window_layers = sliding_window_layers
         self.cache_impl = cache_impl
         self.sliding_window = sliding_window
@@ -207,51 +207,55 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
 
+    def get_decoder_layers(self, causal_lm: PreTrainedModel):
+        return causal_lm.model.layers
+
+    def get_attn_layer(self, layer: nn.Module):
+        return layer.self_attn
+
+    def get_model_layer(self, causal_lm: PreTrainedModel):
+        return causal_lm.model
+
+    def get_rbln_attn_class(self):
+        return DecoderOnlyAttention
+
+    def get_rbln_layer_class(self):
+        return DecoderOnlyLayer
+
+    def get_rbln_model_class(self):
+        return DecoderOnlyModel
+
+    def get_rbln_causal_lm_class(self):
+        return DecoderOnlyForCausalLM
+
     def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
         new_layers = []
-        for layer_idx, layer in enumerate(causal_lm
-                    self.use_attention_mask,
-                    self.use_position_ids,
-                    kvcache_block_size=self.kvcache_block_size,
-                    is_sliding=False,
-                )
-            elif self.attn_impl == "flash_attn":
-                new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn,
-                    kvcache_partition_len=self.kvcache_partition_len,
-                    kvcache_block_size=self.kvcache_block_size,
-                    use_attention_mask=self.use_attention_mask,
-                    use_position_ids=self.use_position_ids,
-                )
-            else:
-                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-
-            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+        for layer_idx, layer in enumerate(self.get_decoder_layers(causal_lm)):
+            is_sliding = layer_idx in self.sliding_window_layers
+            new_self_attn = self.get_rbln_attn_class()(
+                self.get_attn_layer(layer),
+                self.use_attention_mask if not is_sliding else True,
+                self.use_position_ids,
+                kvcache_block_size=self.sliding_window
+                if layer_idx in self.sliding_window_layers
+                else self.kvcache_block_size,
+                is_sliding=is_sliding,
+                attn_impl=self.attn_impl if not is_sliding else "eager",
+                kvcache_partition_len=self.kvcache_partition_len,
+            )
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
             new_layers.append(new_layer)
 
-        new_model =
-            causal_lm
+        new_model = self.get_rbln_model_class()(
+            self.get_model_layer(causal_lm),
             new_layers,
             partition_len=self.kvcache_partition_len,
             max_seq_len=max_seq_len,
            kvcache_block_size=self.kvcache_block_size,
-            use_learned_pos_emb=self.
+            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
            sliding_window_layers=self.sliding_window_layers,
        )
-        new_causal_lm =
+        new_causal_lm = self.get_rbln_causal_lm_class()(causal_lm, new_model)
         return new_causal_lm
 
     @property
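The hunk above replaces the hard-coded eager/flash branching inside `convert_to_rbln_causal_lm` with overridable accessor and factory hooks. As a rough illustration (not code from this diff; `MyModelWrapper` and the `transformer.h` / `layer.attn` attribute paths are hypothetical placeholders), a model whose layers do not live under `model.layers` now only redirects the hooks:

```python
# Hedged sketch of the new hook-based extension point. The import path exists in
# this package; the subclass and its attribute paths are illustrative only.
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper


class MyModelWrapper(DecoderOnlyWrapper):
    def get_decoder_layers(self, causal_lm):
        return causal_lm.transformer.h  # instead of the default causal_lm.model.layers

    def get_attn_layer(self, layer):
        return layer.attn  # instead of the default layer.self_attn

    def get_model_layer(self, causal_lm):
        return causal_lm.transformer  # instead of the default causal_lm.model

    # get_rbln_attn_class / get_rbln_layer_class / get_rbln_model_class keep their
    # DecoderOnly* defaults unless the model needs custom RBLN classes.
```

The Exaone hunk later in this diff follows exactly this pattern instead of reimplementing `convert_to_rbln_causal_lm`.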
@@ -679,9 +683,23 @@ class DecoderOnlyAttention(nn.Module):
 
     Args:
         self_attn: Original attention module from the base model
+        use_attention_mask: Whether to use attention mask
+        use_position_ids: Whether to use position ids
+        kvcache_block_size: Block size for KV cache
+        is_sliding: Whether this is sliding window attention
+        attn_impl: Attention implementation type ("eager" or "flash_attn")
     """
 
-    def __init__(
+    def __init__(
+        self,
+        self_attn,
+        use_attention_mask,
+        use_position_ids,
+        kvcache_block_size,
+        is_sliding=False,
+        attn_impl="eager",
+        kvcache_partition_len=None,
+    ):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -702,10 +720,24 @@ class DecoderOnlyAttention(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
         self.is_sliding = is_sliding
-        self.
+        self.attn_impl = attn_impl
+        self.kvcache_partition_len = kvcache_partition_len
+
+        setattr(self, self.get_attention_name(), self.create_attention_op())
         self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
 
+    def get_attention_name(self):
+        if self.is_sliding:
+            return "sliding_window_attention"
+        elif self.attn_impl == "flash_attn":
+            return "flash_attention"
+        else:
+            return "attention"
+
+    def get_attention_op(self):
+        return getattr(self, self.get_attention_name())
+
     @property
     def phase(self):
         return self._phase
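With this change the attention op is no longer stored under a fixed attribute; the submodule name depends on the configuration, and `get_attention_op()` is the indirection that both `forward` and the `phase` setter use. A minimal standalone sketch of the naming rule (my paraphrase of the hunk, not code from the package):

```python
# Assumed paraphrase of get_attention_name(): the slot that holds the attention
# op is chosen from the sliding-window flag first, then from attn_impl.
def attention_slot_name(is_sliding: bool, attn_impl: str) -> str:
    if is_sliding:
        return "sliding_window_attention"
    if attn_impl == "flash_attn":
        return "flash_attention"
    return "attention"


assert attention_slot_name(False, "flash_attn") == "flash_attention"
assert attention_slot_name(True, "flash_attn") == "sliding_window_attention"  # sliding takes priority
```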
@@ -713,17 +745,36 @@ class DecoderOnlyAttention(nn.Module):
     @phase.setter
     def phase(self, phase: str):
         self._phase = phase
-        self.
+        getattr(self, self.get_attention_name()).phase = phase
 
-    def
+    def create_attention_op(self):
         if self.is_sliding:
             return SlidingWindowAttentionOp(
-                self.num_heads,
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.use_attention_mask,
+                self.use_position_ids,
             )
+        elif self.attn_impl == "flash_attn":
+            return FlashAttentionOp(
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.kvcache_partition_len,
+                self.use_attention_mask,
+                self.use_position_ids,
+            )
+        elif self.attn_impl == "eager":
             return AttentionOp(
-                self.num_heads,
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.use_attention_mask,
+                self.use_position_ids,
             )
+        else:
+            raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
 
     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -780,7 +831,7 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and "prefill" in self.phase:
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
 
-        attn_output = self.
+        attn_output = self.get_attention_op()(
             query_states,
             key_states,
             value_states,
@@ -797,6 +848,14 @@ class DecoderOnlyAttention(nn.Module):
         return attn_outputs
 
 
+class DecoderOnlyFlashAttention(DecoderOnlyAttention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        logger.warning(
+            "DecoderOnlyFlashAttention is deprecated and may not work as expected. Use DecoderOnlyAttention instead."
+        )
+
+
 class AttentionOp(nn.Module):
     def __init__(
         self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
@@ -809,6 +868,18 @@ class AttentionOp(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
 
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+
+        if self.use_attention_mask and not self.use_position_ids:
+            attn_op_name = "paged_attn_"
+        else:
+            attn_op_name = "paged_causal_attn_"
+
+        attn_op_name += phase
+
+        return attn_op_name
+
     def forward(
         self,
         query_state: torch.Tensor,
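`get_attn_op_name` composes the name of the RBLN custom op that `forward` then resolves via `getattr(torch.ops.rbln_custom_ops, ...)`. A standalone sketch of the composition rule for the base `AttentionOp` (my paraphrase of the hunk above, not package code):

```python
# Assumed paraphrase of AttentionOp.get_attn_op_name(): the family prefix comes
# from the mask/position-id flags, the suffix from the phase ("decode" vs.
# anything else, which maps to "prefill").
def attn_op_name(use_attention_mask: bool, use_position_ids: bool, phase: str) -> str:
    prefix = "paged_attn_" if (use_attention_mask and not use_position_ids) else "paged_causal_attn_"
    return prefix + ("decode" if phase == "decode" else "prefill")


assert attn_op_name(True, False, "decode") == "paged_attn_decode"
assert attn_op_name(True, True, "image_prefill") == "paged_causal_attn_prefill"
```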
@@ -857,63 +928,31 @@ class AttentionOp(nn.Module):
             self.head_dim,
         )
 
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
-        else:
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                    is_bidirectional=True if self.phase == "image_prefill" else False,  # FIXME, Hard-coded for Gemma3.
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "seq": seq_position,
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if self.use_attention_mask != self.use_position_ids:
+            op_args["mask"] = attn_mask
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
@@ -1012,70 +1051,6 @@ class RotaryEmbedding(nn.Module):
         )
 
 
-class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask, use_position_ids):
-        self.kvcache_partition_size = kvcache_partition_len
-        super().__init__(
-            self_attn=self_attn,
-            use_attention_mask=use_attention_mask,
-            use_position_ids=use_position_ids,
-            kvcache_block_size=kvcache_block_size,
-        )
-
-    def get_attention(self):
-        return FlashAttentionOp(
-            self.num_heads,
-            self.head_dim,
-            self.num_key_value_heads,
-            self.kvcache_partition_size,
-            self.use_attention_mask,
-            self.use_position_ids,
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        seq_positions: torch.LongTensor,
-        past_key_values: Tuple[Tuple[torch.Tensor]],
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        block_tables: Optional[torch.Tensor] = None,
-    ):
-        batch_size, query_length, _ = hidden_states.size()
-
-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
-
-        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
-            1, 2
-        )
-
-        if hasattr(self, "q_norm") and hasattr(self, "k_norm"):
-            query_states = self.q_norm(query_states)
-            key_states = self.k_norm(key_states)
-
-        if cos is not None and sin is not None:
-            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
-
-        attn_output = self.attention(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            past_key_state=past_key_values[self.layer_idx][0],
-            past_value_state=past_key_values[self.layer_idx][1],
-            seq_position=seq_positions,
-            scale=self.scale,
-            block_tables=block_tables,
-            kvcache_block_size=self.kvcache_block_size,
-        )
-
-        attn_outputs = self.o_proj(attn_output)
-        return attn_outputs
-
-
 class FlashAttentionOp(AttentionOp):
     def __init__(
         self,
@@ -1095,6 +1070,17 @@ class FlashAttentionOp(AttentionOp):
         )
         self.kvcache_partition_size = kvcache_partition_len
 
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if self.use_attention_mask and not self.use_position_ids:
+            attn_op_name = "paged_flash_attn_"
+        else:
+            attn_op_name = "paged_flash_causal_attn_"
+
+        attn_op_name += phase
+
+        return attn_op_name
+
     def forward(
         self,
         query_state,
@@ -1106,7 +1092,7 @@ class FlashAttentionOp(AttentionOp):
         seq_position,
         scale,
         block_tables,
-        kvcache_block_size,
+        block_size,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -1127,67 +1113,32 @@ class FlashAttentionOp(AttentionOp):
             self.head_dim,
         )
 
-                    partition=self.kvcache_partition_size,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-        else:
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                    is_bidirectional=True if self.phase == "image_prefill" else False,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
-        # reshape for removing repeat_kv
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "seq": seq_position,
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+            "partition": self.kvcache_partition_size,
+        }
+
+        if self.use_attention_mask:
+            op_args["mask"] = attn_mask
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
@@ -1196,6 +1147,14 @@ class FlashAttentionOp(AttentionOp):
 
 
 class SlidingWindowAttentionOp(AttentionOp):
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if not self.use_attention_mask:
+            raise NotImplementedError("Attention mask is needed for sliding window attention.")
+
+        attn_op_name = "paged_sliding_window_attn_" + phase
+        return attn_op_name
+
     def forward(
         self,
         query_state: torch.Tensor,
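`FlashAttentionOp` and `SlidingWindowAttentionOp` override only the name composition; the dict-based dispatch in their `forward` methods keeps the same shape as in `AttentionOp`. A consolidated summary of the three op-name families (my reading of the hunks above, not code from the package):

```python
# Assumed summary: prefix per op class; the phase suffix ("decode" or "prefill")
# is appended as in AttentionOp.get_attn_op_name(). SlidingWindowAttentionOp
# requires use_attention_mask and always uses its single family.
ATTN_OP_PREFIXES = {
    "AttentionOp": {"masked": "paged_attn_", "causal": "paged_causal_attn_"},
    "FlashAttentionOp": {"masked": "paged_flash_attn_", "causal": "paged_flash_causal_attn_"},
    "SlidingWindowAttentionOp": {"always": "paged_sliding_window_attn_"},
}

assert ATTN_OP_PREFIXES["FlashAttentionOp"]["causal"] + "prefill" == "paged_flash_causal_attn_prefill"
```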
@@ -1226,35 +1185,29 @@ class SlidingWindowAttentionOp(AttentionOp):
             self.head_dim,
         )
 
-                block_table=block_tables,
-                block_size=block_size,
-                is_bidirectional=True if self.phase == "image_prefill" else False,
-            )
-
-        # reshape for removing repeat_kv
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "cache_seq_len": seq_position[0],
+            "cache_offset": seq_position[1],
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
optimum/rbln/transformers/models/exaone/exaone_architecture.py
@@ -19,8 +19,6 @@ import torch.nn as nn
 from ....utils import logging
 from ...models.decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-    DecoderOnlyFlashAttention,
-    DecoderOnlyForCausalLM,
     DecoderOnlyLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
@@ -36,38 +34,23 @@ logger = logging.get_logger(__name__)
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""
 
-    def
-                )
-            else:
-                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-
-            new_layer = ExaoneLayer(layer, new_self_attn)
-            new_layers.append(new_layer)
-        new_model = ExaoneModel(
-            causal_lm.transformer,
-            new_layers,
-            partition_len=self.kvcache_partition_len,
-            max_seq_len=max_seq_len,
-            sliding_window_layers=self.sliding_window_layers,
-        )
-        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
-        return new_causal_lm
+    def get_decoder_layers(self, causal_lm: "ExaoneForCausalLM"):
+        return causal_lm.transformer.h
+
+    def get_attn_layer(self, layer: nn.Module):
+        return layer.attn.attention
+
+    def get_model_layer(self, causal_lm: "ExaoneForCausalLM"):
+        return causal_lm.transformer
+
+    def get_rbln_attn_class(self):
+        return ExaoneAttention
+
+    def get_rbln_layer_class(self):
+        return ExaoneLayer
+
+    def get_rbln_model_class(self):
+        return ExaoneModel
 
 
 class ExaoneModel(DecoderOnlyModel):
@@ -92,11 +75,3 @@ class ExaoneAttention(DecoderOnlyAttention):
         self.k_proj = self._original_mod.k_proj
         self.v_proj = self._original_mod.v_proj
         self.o_proj = self._original_mod.out_proj
-
-
-class ExaoneFlashAttention(DecoderOnlyFlashAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.out_proj