optimum-rbln 0.8.1rc0__py3-none-any.whl → 0.8.2a0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries, and is provided for informational purposes only.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/configurations/models/__init__.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +5 -3
- optimum/rbln/diffusers/configurations/models/{configuration_cosmos_transformer.py → configuration_transformer_cosmos.py} +4 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +10 -6
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +5 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +212 -257
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +2 -40
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +18 -22
- optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
- optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
- optimum/rbln/transformers/models/phi/phi_architecture.py +14 -20
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2a0.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2a0.dist-info}/RECORD +22 -22
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2a0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2a0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -149,6 +149,8 @@ class DecoderOnlyWrapper(nn.Module):
         This is only relevant if `attn_impl` is set to "flash_attn`
     """
 
+    _use_learned_pos_emb = False
+
     def __init__(
         self,
         causal_lm: PreTrainedModel,
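The new class-level `_use_learned_pos_emb` flag replaces the `use_learned_pos_emb` constructor argument removed in the next two hunks; `convert_to_rbln_causal_lm` later reads it via `self.__class__._use_learned_pos_emb`, so a wrapper opts in by overriding the attribute. A minimal sketch of that usage, assuming optimum-rbln is installed (the subclass name is hypothetical):

    from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
        DecoderOnlyWrapper,
    )

    class LearnedPosEmbWrapper(DecoderOnlyWrapper):
        # Hypothetical wrapper: opting into learned position embeddings is now
        # a class-level declaration instead of a constructor argument.
        _use_learned_pos_emb = True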
@@ -159,7 +161,6 @@ class DecoderOnlyWrapper(nn.Module):
         use_inputs_embeds: bool,
         use_attention_mask: bool,
         use_position_ids: bool,
-        use_learned_pos_emb: Optional[bool] = None,
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
         sliding_window: Optional[int] = None,
@@ -182,7 +183,6 @@ class DecoderOnlyWrapper(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
         self.use_inputs_embeds = use_inputs_embeds
-        self.use_learned_pos_emb = use_learned_pos_emb
         self.sliding_window_layers = sliding_window_layers
         self.cache_impl = cache_impl
         self.sliding_window = sliding_window
@@ -207,51 +207,54 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
 
+    def get_decoder_layers(self, causal_lm: PreTrainedModel):
+        return causal_lm.model.layers
+
+    def get_attn_layer(self, layer: nn.Module):
+        return layer.self_attn
+
+    def get_model_layer(self, causal_lm: PreTrainedModel):
+        return causal_lm.model
+
+    def get_rbln_attn_class(self):
+        return DecoderOnlyAttention
+
+    def get_rbln_layer_class(self):
+        return DecoderOnlyLayer
+
+    def get_rbln_model_class(self):
+        return DecoderOnlyModel
+
+    def get_rbln_causal_lm_class(self):
+        return DecoderOnlyForCausalLM
+
     def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
         new_layers = []
-        for layer_idx, layer in enumerate(causal_lm.model.layers):
-            … [old lines 213-224 omitted]
-                    layer.self_attn,
-                    self.use_attention_mask,
-                    self.use_position_ids,
-                    kvcache_block_size=self.kvcache_block_size,
-                    is_sliding=False,
-                )
-            elif self.attn_impl == "flash_attn":
-                new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn,
-                    kvcache_partition_len=self.kvcache_partition_len,
-                    kvcache_block_size=self.kvcache_block_size,
-                    use_attention_mask=self.use_attention_mask,
-                    use_position_ids=self.use_position_ids,
-                )
-            else:
-                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-
-            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+        for layer_idx, layer in enumerate(self.get_decoder_layers(causal_lm)):
+            new_self_attn = self.get_rbln_attn_class()(
+                self.get_attn_layer(layer),
+                self.use_attention_mask,
+                self.use_position_ids,
+                kvcache_block_size=self.sliding_window
+                if layer_idx in self.sliding_window_layers
+                else self.kvcache_block_size,
+                is_sliding=layer_idx in self.sliding_window_layers,
+                attn_impl=self.attn_impl,
+                kvcache_partition_len=self.kvcache_partition_len,
+            )
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
             new_layers.append(new_layer)
 
-        new_model = DecoderOnlyModel(
-            causal_lm.model,
+        new_model = self.get_rbln_model_class()(
+            self.get_model_layer(causal_lm),
             new_layers,
             partition_len=self.kvcache_partition_len,
             max_seq_len=max_seq_len,
             kvcache_block_size=self.kvcache_block_size,
-            use_learned_pos_emb=self.use_learned_pos_emb,
+            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
            sliding_window_layers=self.sliding_window_layers,
         )
-        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        new_causal_lm = self.get_rbln_causal_lm_class()(causal_lm, new_model)
         return new_causal_lm
 
     @property
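The rewritten `convert_to_rbln_causal_lm` above no longer branches on `attn_impl` or hard-codes `causal_lm.model`; every model-specific lookup goes through the new `get_*` hooks, so an architecture wrapper only overrides accessors instead of re-implementing the conversion loop. A minimal sketch of such a wrapper for a hypothetical model whose decoder blocks live under `model.decoder.layers` (only the hook names are taken from the hunk above; the module layout is illustrative):

    from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
        DecoderOnlyWrapper,
    )

    class MyDecoderWrapper(DecoderOnlyWrapper):
        # Point the generic conversion loop at this architecture's layout.
        def get_decoder_layers(self, causal_lm):
            return causal_lm.model.decoder.layers  # hypothetical layout

        def get_attn_layer(self, layer):
            return layer.self_attn

        def get_model_layer(self, causal_lm):
            return causal_lm.model.decoder  # hypothetical layout

The Exaone hunks at the end of this diff use exactly this mechanism to map `transformer.h` and `attn.attention` onto the generic path.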
@@ -679,9 +682,23 @@ class DecoderOnlyAttention(nn.Module):
 
     Args:
         self_attn: Original attention module from the base model
+        use_attention_mask: Whether to use attention mask
+        use_position_ids: Whether to use position ids
+        kvcache_block_size: Block size for KV cache
+        is_sliding: Whether this is sliding window attention
+        attn_impl: Attention implementation type ("eager" or "flash_attn")
     """
 
-    def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size, is_sliding=False):
+    def __init__(
+        self,
+        self_attn,
+        use_attention_mask,
+        use_position_ids,
+        kvcache_block_size,
+        is_sliding=False,
+        attn_impl="eager",
+        kvcache_partition_len=None,
+    ):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -702,10 +719,28 @@ class DecoderOnlyAttention(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
         self.is_sliding = is_sliding
-        self.attention = self.get_attention()
+        self.attn_impl = attn_impl
+
+        if self.is_sliding and self.attn_impl != "eager":
+            raise NotImplementedError("Sliding window attention is only supported with eager attention.")
+
+        self.kvcache_partition_len = kvcache_partition_len
+
+        setattr(self, self.get_attention_name(), self.create_attention_op())
         self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
 
+    def get_attention_name(self):
+        if self.is_sliding:
+            return "sliding_window_attention"
+        elif self.attn_impl == "flash_attn":
+            return "flash_attention"
+        else:
+            return "attention"
+
+    def get_attention_op(self):
+        return getattr(self, self.get_attention_name())
+
     @property
     def phase(self):
         return self._phase
@@ -713,17 +748,36 @@ class DecoderOnlyAttention(nn.Module):
     @phase.setter
     def phase(self, phase: str):
         self._phase = phase
-        self.attention.phase = phase
+        getattr(self, self.get_attention_name()).phase = phase
 
-    def get_attention(self):
+    def create_attention_op(self):
         if self.is_sliding:
             return SlidingWindowAttentionOp(
-                self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.use_attention_mask,
+                self.use_position_ids,
             )
-        else:
+        elif self.attn_impl == "flash_attn":
+            return FlashAttentionOp(
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.kvcache_partition_len,
+                self.use_attention_mask,
+                self.use_position_ids,
+            )
+        elif self.attn_impl == "eager":
             return AttentionOp(
-                self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.use_attention_mask,
+                self.use_position_ids,
             )
+        else:
+            raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
 
     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
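`create_attention_op` now selects the op class from `is_sliding` and `attn_impl`, while `get_attention_name` picks the attribute the op is stored under (matching the `setattr` call added to `__init__`). A standalone sketch of the same selection logic, with stand-in classes so it runs without the RBLN runtime:

    # Stand-ins for AttentionOp, FlashAttentionOp and SlidingWindowAttentionOp.
    class AttentionOp: ...
    class FlashAttentionOp: ...
    class SlidingWindowAttentionOp: ...

    def select_attention_op(is_sliding: bool, attn_impl: str):
        # Mirrors the dispatch in create_attention_op() / get_attention_name().
        if is_sliding:
            return SlidingWindowAttentionOp, "sliding_window_attention"
        elif attn_impl == "flash_attn":
            return FlashAttentionOp, "flash_attention"
        elif attn_impl == "eager":
            return AttentionOp, "attention"
        raise NotImplementedError(f"Unknown attention implementation: {attn_impl}")

    assert select_attention_op(True, "eager") == (SlidingWindowAttentionOp, "sliding_window_attention")
    assert select_attention_op(False, "flash_attn") == (FlashAttentionOp, "flash_attention")
    assert select_attention_op(False, "eager") == (AttentionOp, "attention")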
@@ -780,7 +834,7 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and "prefill" in self.phase:
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
 
-        attn_output = self.attention(
+        attn_output = self.get_attention_op()(
             query_states,
             key_states,
             value_states,
@@ -797,6 +851,14 @@ class DecoderOnlyAttention(nn.Module):
         return attn_outputs
 
 
+class DecoderOnlyFlashAttention(DecoderOnlyAttention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        logger.warning(
+            "DecoderOnlyFlashAttention is deprecated and may not work as expected. Use DecoderOnlyAttention instead."
+        )
+
+
 class AttentionOp(nn.Module):
     def __init__(
         self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
@@ -809,6 +871,17 @@ class AttentionOp(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
 
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if self.use_attention_mask:
+            attn_op_name = "paged_attn_"
+        else:
+            attn_op_name = "paged_causal_attn_"
+
+        attn_op_name += phase
+
+        return attn_op_name
+
     def forward(
         self,
         query_state: torch.Tensor,
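`get_attn_op_name` composes the custom-op name from the masking mode and the phase, and `forward` (next hunk) resolves that name on `torch.ops.rbln_custom_ops` with `getattr`. A small pure-Python sketch of the naming scheme introduced above:

    def attn_op_name(use_attention_mask: bool, phase: str) -> str:
        # Mirrors AttentionOp.get_attn_op_name(): every non-decode phase
        # (e.g. "prefill" or "image_prefill") maps onto the *_prefill op.
        phase = "decode" if phase == "decode" else "prefill"
        prefix = "paged_attn_" if use_attention_mask else "paged_causal_attn_"
        return prefix + phase

    assert attn_op_name(True, "decode") == "paged_attn_decode"
    assert attn_op_name(False, "prefill") == "paged_causal_attn_prefill"
    assert attn_op_name(False, "image_prefill") == "paged_causal_attn_prefill"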
@@ -857,63 +930,31 @@ class AttentionOp(nn.Module):
             self.head_dim,
         )
 
-        … [old lines 860-884 omitted]
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
-        else:
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                    is_bidirectional=True if self.phase == "image_prefill" else False, # FIXME, Hard-coded for Gemma3.
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "seq": seq_position,
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if self.use_attention_mask != self.use_position_ids:
+            op_args["mask"] = attn_mask
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill" # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
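Instead of duplicating the custom-op call per branch, the rewritten `forward` builds one `op_args` dict and only adds `mask` and `is_bidirectional` when the flag combination calls for them. A sketch of that keyword assembly (the function name and toy `base_args` are illustrative; the conditions mirror the hunk above):

    def build_attn_kwargs(base_args: dict, use_attention_mask: bool,
                          use_position_ids: bool, phase: str, attn_mask=None) -> dict:
        op_args = dict(base_args)
        # mask is passed only when exactly one of the two flags is set.
        if use_attention_mask != use_position_ids:
            op_args["mask"] = attn_mask
        # prefill-style ops additionally take is_bidirectional; the
        # image_prefill phase is the bidirectional one (Gemma3 special case).
        if phase in ("prefill", "image_prefill"):
            if not use_attention_mask or use_position_ids:
                op_args["is_bidirectional"] = phase == "image_prefill"
        return op_args

    kwargs = build_attn_kwargs({"scale": 1.0}, use_attention_mask=False,
                               use_position_ids=False, phase="prefill")
    assert kwargs == {"scale": 1.0, "is_bidirectional": False}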
@@ -1012,70 +1053,6 @@ class RotaryEmbedding(nn.Module):
         )
 
 
-class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask, use_position_ids):
-        self.kvcache_partition_size = kvcache_partition_len
-        super().__init__(
-            self_attn=self_attn,
-            use_attention_mask=use_attention_mask,
-            use_position_ids=use_position_ids,
-            kvcache_block_size=kvcache_block_size,
-        )
-
-    def get_attention(self):
-        return FlashAttentionOp(
-            self.num_heads,
-            self.head_dim,
-            self.num_key_value_heads,
-            self.kvcache_partition_size,
-            self.use_attention_mask,
-            self.use_position_ids,
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        seq_positions: torch.LongTensor,
-        past_key_values: Tuple[Tuple[torch.Tensor]],
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        block_tables: Optional[torch.Tensor] = None,
-    ):
-        batch_size, query_length, _ = hidden_states.size()
-
-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
-
-        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
-            1, 2
-        )
-
-        if hasattr(self, "q_norm") and hasattr(self, "k_norm"):
-            query_states = self.q_norm(query_states)
-            key_states = self.k_norm(key_states)
-
-        if cos is not None and sin is not None:
-            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
-
-        attn_output = self.attention(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            past_key_state=past_key_values[self.layer_idx][0],
-            past_value_state=past_key_values[self.layer_idx][1],
-            seq_position=seq_positions,
-            scale=self.scale,
-            block_tables=block_tables,
-            kvcache_block_size=self.kvcache_block_size,
-        )
-
-        attn_outputs = self.o_proj(attn_output)
-        return attn_outputs
-
-
 class FlashAttentionOp(AttentionOp):
     def __init__(
         self,
@@ -1095,6 +1072,17 @@ class FlashAttentionOp(AttentionOp):
         )
         self.kvcache_partition_size = kvcache_partition_len
 
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if self.use_attention_mask:
+            attn_op_name = "paged_flash_attn_"
+        else:
+            attn_op_name = "paged_flash_causal_attn_"
+
+        attn_op_name += phase
+
+        return attn_op_name
+
     def forward(
         self,
         query_state,
@@ -1106,7 +1094,7 @@ class FlashAttentionOp(AttentionOp):
         seq_position,
         scale,
         block_tables,
-        kvcache_block_size,
+        block_size,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -1127,67 +1115,32 @@ class FlashAttentionOp(AttentionOp):
             self.head_dim,
         )
 
-        … [old lines 1130-1155 omitted]
-                    partition=self.kvcache_partition_size,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-        else:
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                    is_bidirectional=True if self.phase == "image_prefill" else False,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
-        # reshape for removing repeat_kv
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "seq": seq_position,
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+            "partition": self.kvcache_partition_size,
+        }
+
+        if self.use_attention_mask != self.use_position_ids:
+            op_args["mask"] = attn_mask
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill" # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
@@ -1196,6 +1149,14 @@ class FlashAttentionOp(AttentionOp):
 
 
 class SlidingWindowAttentionOp(AttentionOp):
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if self.use_attention_mask:
+            raise NotImplementedError("Attention mask is not supported for sliding window attention.")
+
+        attn_op_name = "paged_sliding_window_attn_" + phase
+        return attn_op_name
+
     def forward(
         self,
         query_state: torch.Tensor,
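Each op class now owns its name construction: `FlashAttentionOp` uses the `paged_flash_*` prefixes, and `SlidingWindowAttentionOp` (above) only supports the mask-less causal path and raises otherwise. A combined pure-Python sketch of the three schemes added in this diff:

    def op_name(kind: str, use_attention_mask: bool, phase: str) -> str:
        # kind is "eager", "flash" or "sliding"; mirrors the three
        # get_attn_op_name() implementations in this diff.
        phase = "decode" if phase == "decode" else "prefill"
        if kind == "sliding":
            if use_attention_mask:
                raise NotImplementedError("Attention mask is not supported for sliding window attention.")
            return "paged_sliding_window_attn_" + phase
        if kind == "flash":
            prefix = "paged_flash_attn_" if use_attention_mask else "paged_flash_causal_attn_"
        else:
            prefix = "paged_attn_" if use_attention_mask else "paged_causal_attn_"
        return prefix + phase

    assert op_name("flash", False, "decode") == "paged_flash_causal_attn_decode"
    assert op_name("sliding", False, "image_prefill") == "paged_sliding_window_attn_prefill"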
@@ -1226,35 +1187,29 @@ class SlidingWindowAttentionOp(AttentionOp):
             self.head_dim,
         )
 
-        … [old lines 1229-1251 omitted]
-                block_table=block_tables,
-                block_size=block_size,
-                is_bidirectional=True if self.phase == "image_prefill" else False,
-            )
-
-        # reshape for removing repeat_kv
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "cache_seq_len": seq_position[0],
+            "cache_offset": seq_position[1],
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill" # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
optimum/rbln/transformers/models/exaone/exaone_architecture.py

@@ -19,8 +19,6 @@ import torch.nn as nn
 from ....utils import logging
 from ...models.decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-    DecoderOnlyFlashAttention,
-    DecoderOnlyForCausalLM,
     DecoderOnlyLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
@@ -36,38 +34,23 @@ logger = logging.get_logger(__name__)
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""
 
-    def convert_to_rbln_causal_lm(…):
-        … [old lines 40-55 omitted]
-                )
-            else:
-                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-
-            new_layer = ExaoneLayer(layer, new_self_attn)
-            new_layers.append(new_layer)
-        new_model = ExaoneModel(
-            causal_lm.transformer,
-            new_layers,
-            partition_len=self.kvcache_partition_len,
-            max_seq_len=max_seq_len,
-            sliding_window_layers=self.sliding_window_layers,
-        )
-        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
-        return new_causal_lm
+    def get_decoder_layers(self, causal_lm: "ExaoneForCausalLM"):
+        return causal_lm.transformer.h
+
+    def get_attn_layer(self, layer: nn.Module):
+        return layer.attn.attention
+
+    def get_model_layer(self, causal_lm: "ExaoneForCausalLM"):
+        return causal_lm.transformer
+
+    def get_rbln_attn_class(self):
+        return ExaoneAttention
+
+    def get_rbln_layer_class(self):
+        return ExaoneLayer
+
+    def get_rbln_model_class(self):
+        return ExaoneModel
 
 
 class ExaoneModel(DecoderOnlyModel):
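With the base-class hooks in place, the Exaone wrapper above shrinks to a set of accessors that map Exaone's GPT-style module tree (`transformer.h`, `attn.attention`) onto the generic conversion path. A self-contained sketch of that mapping against a hypothetical, minimal stand-in for the Hugging Face model (only the attributes the hooks read are modelled):

    import torch.nn as nn

    class _Attn(nn.Module):
        def __init__(self):
            super().__init__()
            self.attention = nn.Identity()      # what get_attn_layer() returns

    class _Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.attn = _Attn()

    class _Transformer(nn.Module):
        def __init__(self):
            super().__init__()
            self.h = nn.ModuleList([_Block()])  # what get_decoder_layers() returns

    class _ExaoneLike(nn.Module):
        def __init__(self):
            super().__init__()
            self.transformer = _Transformer()   # what get_model_layer() returns

    model = _ExaoneLike()
    layers = model.transformer.h                # get_decoder_layers(model)
    attn = layers[0].attn.attention             # get_attn_layer(layers[0])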
@@ -92,11 +75,3 @@ class ExaoneAttention(DecoderOnlyAttention):
         self.k_proj = self._original_mod.k_proj
         self.v_proj = self._original_mod.v_proj
         self.o_proj = self._original_mod.out_proj
-
-
-class ExaoneFlashAttention(DecoderOnlyFlashAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.out_proj