optimum-rbln 0.8.1rc1__py3-none-any.whl → 0.8.2a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -149,6 +149,8 @@ class DecoderOnlyWrapper(nn.Module):
             This is only relevant if `attn_impl` is set to "flash_attn`
     """
 
+    _use_learned_pos_emb = False
+
     def __init__(
         self,
         causal_lm: PreTrainedModel,
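
The learned-positional-embedding switch moves from a constructor argument to a class-level flag, so wrappers opt in by overriding the attribute rather than threading a parameter through `__init__`. A minimal sketch, assuming a hypothetical wrapper for a model that needs learned position embeddings and that `DecoderOnlyWrapper` from this module is in scope:

    # Hypothetical subclass; only the class attribute changes, __init__ stays inherited.
    class LearnedPosEmbWrapper(DecoderOnlyWrapper):
        _use_learned_pos_emb = True
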
@@ -159,7 +161,6 @@ class DecoderOnlyWrapper(nn.Module):
         use_inputs_embeds: bool,
         use_attention_mask: bool,
         use_position_ids: bool,
-        use_learned_pos_emb: Optional[bool] = None,
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
         sliding_window: Optional[int] = None,
@@ -182,7 +183,6 @@ class DecoderOnlyWrapper(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
         self.use_inputs_embeds = use_inputs_embeds
-        self.use_learned_pos_emb = use_learned_pos_emb
         self.sliding_window_layers = sliding_window_layers
         self.cache_impl = cache_impl
         self.sliding_window = sliding_window
@@ -207,51 +207,55 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
 
+    def get_decoder_layers(self, causal_lm: PreTrainedModel):
+        return causal_lm.model.layers
+
+    def get_attn_layer(self, layer: nn.Module):
+        return layer.self_attn
+
+    def get_model_layer(self, causal_lm: PreTrainedModel):
+        return causal_lm.model
+
+    def get_rbln_attn_class(self):
+        return DecoderOnlyAttention
+
+    def get_rbln_layer_class(self):
+        return DecoderOnlyLayer
+
+    def get_rbln_model_class(self):
+        return DecoderOnlyModel
+
+    def get_rbln_causal_lm_class(self):
+        return DecoderOnlyForCausalLM
+
     def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
         new_layers = []
-        for layer_idx, layer in enumerate(causal_lm.model.layers):
-            if layer_idx in self.sliding_window_layers:
-                # Flash attention is not yet supported for sliding window attention.
-                new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn,
-                    self.use_attention_mask,
-                    self.use_position_ids,
-                    kvcache_block_size=self.sliding_window,
-                    is_sliding=True,
-                )
-            else:
-                if self.attn_impl == "eager":
-                    new_self_attn = DecoderOnlyAttention(
-                        layer.self_attn,
-                        self.use_attention_mask,
-                        self.use_position_ids,
-                        kvcache_block_size=self.kvcache_block_size,
-                        is_sliding=False,
-                    )
-                elif self.attn_impl == "flash_attn":
-                    new_self_attn = DecoderOnlyFlashAttention(
-                        layer.self_attn,
-                        kvcache_partition_len=self.kvcache_partition_len,
-                        kvcache_block_size=self.kvcache_block_size,
-                        use_attention_mask=self.use_attention_mask,
-                        use_position_ids=self.use_position_ids,
-                    )
-                else:
-                    raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-
-            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+        for layer_idx, layer in enumerate(self.get_decoder_layers(causal_lm)):
+            is_sliding = layer_idx in self.sliding_window_layers
+            new_self_attn = self.get_rbln_attn_class()(
+                self.get_attn_layer(layer),
+                self.use_attention_mask if not is_sliding else True,
+                self.use_position_ids,
+                kvcache_block_size=self.sliding_window
+                if layer_idx in self.sliding_window_layers
+                else self.kvcache_block_size,
+                is_sliding=is_sliding,
+                attn_impl=self.attn_impl if not is_sliding else "eager",
+                kvcache_partition_len=self.kvcache_partition_len,
+            )
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
             new_layers.append(new_layer)
 
-        new_model = DecoderOnlyModel(
-            causal_lm.model,
+        new_model = self.get_rbln_model_class()(
+            self.get_model_layer(causal_lm),
             new_layers,
             partition_len=self.kvcache_partition_len,
             max_seq_len=max_seq_len,
             kvcache_block_size=self.kvcache_block_size,
-            use_learned_pos_emb=self.use_learned_pos_emb,
+            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
             sliding_window_layers=self.sliding_window_layers,
         )
-        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        new_causal_lm = self.get_rbln_causal_lm_class()(causal_lm, new_model)
         return new_causal_lm
 
     @property
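
`convert_to_rbln_causal_lm` is now driven entirely by overridable lookup hooks, so model-specific wrappers no longer copy the whole conversion loop (the `ExaoneForCausalLMWrapper` hunks further down follow exactly this pattern). A minimal sketch, assuming a hypothetical model whose decoder layers live under `transformer.h` and whose attention module is `layer.attn`:

    # Hypothetical wrapper: override only where this model's structure differs
    # from the default `causal_lm.model.layers` / `layer.self_attn` layout.
    class HypotheticalWrapper(DecoderOnlyWrapper):
        def get_decoder_layers(self, causal_lm):
            return causal_lm.transformer.h

        def get_attn_layer(self, layer):
            return layer.attn

        def get_model_layer(self, causal_lm):
            return causal_lm.transformer
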
@@ -679,9 +683,23 @@ class DecoderOnlyAttention(nn.Module):
 
     Args:
         self_attn: Original attention module from the base model
+        use_attention_mask: Whether to use attention mask
+        use_position_ids: Whether to use position ids
+        kvcache_block_size: Block size for KV cache
+        is_sliding: Whether this is sliding window attention
+        attn_impl: Attention implementation type ("eager" or "flash_attn")
     """
 
-    def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size, is_sliding=False):
+    def __init__(
+        self,
+        self_attn,
+        use_attention_mask,
+        use_position_ids,
+        kvcache_block_size,
+        is_sliding=False,
+        attn_impl="eager",
+        kvcache_partition_len=None,
+    ):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -702,10 +720,24 @@ class DecoderOnlyAttention(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
         self.is_sliding = is_sliding
-        self.attention = self.get_attention()
+        self.attn_impl = attn_impl
+        self.kvcache_partition_len = kvcache_partition_len
+
+        setattr(self, self.get_attention_name(), self.create_attention_op())
         self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
 
+    def get_attention_name(self):
+        if self.is_sliding:
+            return "sliding_window_attention"
+        elif self.attn_impl == "flash_attn":
+            return "flash_attention"
+        else:
+            return "attention"
+
+    def get_attention_op(self):
+        return getattr(self, self.get_attention_name())
+
     @property
     def phase(self):
         return self._phase
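
The attention op is no longer stored under a fixed `self.attention` attribute; `create_attention_op()` registers it under a name chosen by `get_attention_name()`, and `get_attention_op()` fetches it back with `getattr`. A stand-alone sketch of the naming rule, mirroring the hunk above rather than importing it:

    def attention_attr_name(is_sliding: bool, attn_impl: str) -> str:
        # Mirrors DecoderOnlyAttention.get_attention_name() from the hunk above (sketch).
        if is_sliding:
            return "sliding_window_attention"
        if attn_impl == "flash_attn":
            return "flash_attention"
        return "attention"

    assert attention_attr_name(False, "flash_attn") == "flash_attention"
    assert attention_attr_name(True, "flash_attn") == "sliding_window_attention"
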
@@ -713,17 +745,36 @@ class DecoderOnlyAttention(nn.Module):
     @phase.setter
     def phase(self, phase: str):
         self._phase = phase
-        self.attention.phase = phase
+        getattr(self, self.get_attention_name()).phase = phase
 
-    def get_attention(self):
+    def create_attention_op(self):
         if self.is_sliding:
             return SlidingWindowAttentionOp(
-                self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.use_attention_mask,
+                self.use_position_ids,
             )
-        else:
+        elif self.attn_impl == "flash_attn":
+            return FlashAttentionOp(
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.kvcache_partition_len,
+                self.use_attention_mask,
+                self.use_position_ids,
+            )
+        elif self.attn_impl == "eager":
             return AttentionOp(
-                self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.use_attention_mask,
+                self.use_position_ids,
             )
+        else:
+            raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
 
     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -780,7 +831,7 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and "prefill" in self.phase:
            raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
 
-        attn_output = self.attention(
+        attn_output = self.get_attention_op()(
             query_states,
             key_states,
             value_states,
@@ -797,6 +848,14 @@ class DecoderOnlyAttention(nn.Module):
         return attn_outputs
 
 
+class DecoderOnlyFlashAttention(DecoderOnlyAttention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        logger.warning(
+            "DecoderOnlyFlashAttention is deprecated and may not work as expected. Use DecoderOnlyAttention instead."
+        )
+
+
 class AttentionOp(nn.Module):
     def __init__(
         self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
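
With flash attention folded into `DecoderOnlyAttention` via `attn_impl`, the old subclass survives only as a deprecation shim that warns and falls through to the parent `__init__`. A sketch of the migration path, with hypothetical variable names and assuming `self_attn` is the original Hugging Face attention module:

    def build_flash_attention(self_attn, block_size, partition_len):
        # Replaces the removed DecoderOnlyFlashAttention(self_attn, kvcache_partition_len=...,
        # kvcache_block_size=..., use_attention_mask=..., use_position_ids=...) construction
        # shown further down in this diff; assumes DecoderOnlyAttention from this module.
        return DecoderOnlyAttention(
            self_attn,
            use_attention_mask=True,
            use_position_ids=False,
            kvcache_block_size=block_size,
            attn_impl="flash_attn",
            kvcache_partition_len=partition_len,
        )
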
@@ -809,6 +868,18 @@ class AttentionOp(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
 
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+
+        if self.use_attention_mask and not self.use_position_ids:
+            attn_op_name = "paged_attn_"
+        else:
+            attn_op_name = "paged_causal_attn_"
+
+        attn_op_name += phase
+
+        return attn_op_name
+
     def forward(
         self,
         query_state: torch.Tensor,
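
Custom-op selection becomes a name computation plus a lookup instead of a four-way branch. A stand-alone sketch of the names `AttentionOp.get_attn_op_name()` can produce, mirroring the hunk above:

    def resolve_attn_op_name(use_attention_mask: bool, use_position_ids: bool, phase: str) -> str:
        # Mirrors AttentionOp.get_attn_op_name(); FlashAttentionOp and
        # SlidingWindowAttentionOp override the prefix (see later hunks).
        base = "paged_attn_" if (use_attention_mask and not use_position_ids) else "paged_causal_attn_"
        return base + ("decode" if phase == "decode" else "prefill")

    assert resolve_attn_op_name(True, False, "decode") == "paged_attn_decode"
    assert resolve_attn_op_name(True, True, "image_prefill") == "paged_causal_attn_prefill"
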
@@ -857,63 +928,31 @@ class AttentionOp(nn.Module):
             self.head_dim,
         )
 
-        if self.phase == "decode":
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
-        else:
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                    is_bidirectional=True if self.phase == "image_prefill" else False,  # FIXME, Hard-coded for Gemma3.
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "seq": seq_position,
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if self.use_attention_mask != self.use_position_ids:
+            op_args["mask"] = attn_mask
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
@@ -1012,70 +1051,6 @@ class RotaryEmbedding(nn.Module):
         )
 
 
-class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask, use_position_ids):
-        self.kvcache_partition_size = kvcache_partition_len
-        super().__init__(
-            self_attn=self_attn,
-            use_attention_mask=use_attention_mask,
-            use_position_ids=use_position_ids,
-            kvcache_block_size=kvcache_block_size,
-        )
-
-    def get_attention(self):
-        return FlashAttentionOp(
-            self.num_heads,
-            self.head_dim,
-            self.num_key_value_heads,
-            self.kvcache_partition_size,
-            self.use_attention_mask,
-            self.use_position_ids,
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        seq_positions: torch.LongTensor,
-        past_key_values: Tuple[Tuple[torch.Tensor]],
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        block_tables: Optional[torch.Tensor] = None,
-    ):
-        batch_size, query_length, _ = hidden_states.size()
-
-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
-
-        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
-            1, 2
-        )
-
-        if hasattr(self, "q_norm") and hasattr(self, "k_norm"):
-            query_states = self.q_norm(query_states)
-            key_states = self.k_norm(key_states)
-
-        if cos is not None and sin is not None:
-            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
-
-        attn_output = self.attention(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            past_key_state=past_key_values[self.layer_idx][0],
-            past_value_state=past_key_values[self.layer_idx][1],
-            seq_position=seq_positions,
-            scale=self.scale,
-            block_tables=block_tables,
-            kvcache_block_size=self.kvcache_block_size,
-        )
-
-        attn_outputs = self.o_proj(attn_output)
-        return attn_outputs
-
-
 class FlashAttentionOp(AttentionOp):
     def __init__(
         self,
@@ -1095,6 +1070,17 @@ class FlashAttentionOp(AttentionOp):
         )
         self.kvcache_partition_size = kvcache_partition_len
 
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if self.use_attention_mask and not self.use_position_ids:
+            attn_op_name = "paged_flash_attn_"
+        else:
+            attn_op_name = "paged_flash_causal_attn_"
+
+        attn_op_name += phase
+
+        return attn_op_name
+
     def forward(
         self,
         query_state,
@@ -1106,7 +1092,7 @@ class FlashAttentionOp(AttentionOp):
         seq_position,
         scale,
         block_tables,
-        kvcache_block_size,
+        block_size,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -1127,67 +1113,32 @@ class FlashAttentionOp(AttentionOp):
             self.head_dim,
         )
 
-        if self.phase == "decode":
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-        else:
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                    is_bidirectional=True if self.phase == "image_prefill" else False,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
-        # reshape for removing repeat_kv
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "seq": seq_position,
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+            "partition": self.kvcache_partition_size,
+        }
+
+        if self.use_attention_mask:
+            op_args["mask"] = attn_mask
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
@@ -1196,6 +1147,14 @@ class FlashAttentionOp(AttentionOp):
 
 
 class SlidingWindowAttentionOp(AttentionOp):
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if not self.use_attention_mask:
+            raise NotImplementedError("Attention mask is needed for sliding window attention.")
+
+        attn_op_name = "paged_sliding_window_attn_" + phase
+        return attn_op_name
+
     def forward(
         self,
         query_state: torch.Tensor,
@@ -1226,35 +1185,29 @@ class SlidingWindowAttentionOp(AttentionOp):
             self.head_dim,
         )
 
-        if self.phase == "decode":
-            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_decode(
-                q=query_state,
-                k=key_state,
-                v=value_state,
-                kcache=past_key_state.unsqueeze(2),
-                vcache=past_value_state.unsqueeze(2),
-                cache_seq_len=seq_position[0],
-                cache_offset=seq_position[1],
-                scale=scale,
-                block_table=block_tables,
-                block_size=block_size,
-            )
-        else:
-            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_prefill(
-                q=query_state,
-                k=key_state,
-                v=value_state,
-                kcache=past_key_state.unsqueeze(2),
-                vcache=past_value_state.unsqueeze(2),
-                cache_seq_len=seq_position[0],
-                cache_offset=seq_position[1],
-                scale=scale,
-                block_table=block_tables,
-                block_size=block_size,
-                is_bidirectional=True if self.phase == "image_prefill" else False,
-            )
-
-        # reshape for removing repeat_kv
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "cache_seq_len": seq_position[0],
+            "cache_offset": seq_position[1],
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
@@ -19,8 +19,6 @@ import torch.nn as nn
 from ....utils import logging
 from ...models.decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-    DecoderOnlyFlashAttention,
-    DecoderOnlyForCausalLM,
     DecoderOnlyLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
@@ -36,38 +34,23 @@ logger = logging.get_logger(__name__)
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""
 
-    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM", max_seq_len: int):
-        new_layers = []
-        for layer in causal_lm.transformer.h:
-            if self.attn_impl == "eager":
-                new_self_attn = ExaoneAttention(
-                    layer.attn.attention,
-                    self.use_attention_mask,
-                    kvcache_block_size=self.kvcache_block_size,
-                    use_position_ids=self.use_position_ids,
-                )
-            elif self.attn_impl == "flash_attn":
-                new_self_attn = ExaoneFlashAttention(
-                    layer.attn.attention,
-                    kvcache_partition_len=self.kvcache_partition_len,
-                    use_attention_mask=self.use_attention_mask,
-                    kvcache_block_size=self.kvcache_block_size,
-                    use_position_ids=self.use_position_ids,
-                )
-            else:
-                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-
-            new_layer = ExaoneLayer(layer, new_self_attn)
-            new_layers.append(new_layer)
-        new_model = ExaoneModel(
-            causal_lm.transformer,
-            new_layers,
-            partition_len=self.kvcache_partition_len,
-            max_seq_len=max_seq_len,
-            sliding_window_layers=self.sliding_window_layers,
-        )
-        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
-        return new_causal_lm
+    def get_decoder_layers(self, causal_lm: "ExaoneForCausalLM"):
+        return causal_lm.transformer.h
+
+    def get_attn_layer(self, layer: nn.Module):
+        return layer.attn.attention
+
+    def get_model_layer(self, causal_lm: "ExaoneForCausalLM"):
+        return causal_lm.transformer
+
+    def get_rbln_attn_class(self):
+        return ExaoneAttention
+
+    def get_rbln_layer_class(self):
+        return ExaoneLayer
+
+    def get_rbln_model_class(self):
+        return ExaoneModel
 
 
 class ExaoneModel(DecoderOnlyModel):
@@ -92,11 +75,3 @@ class ExaoneAttention(DecoderOnlyAttention):
         self.k_proj = self._original_mod.k_proj
         self.v_proj = self._original_mod.v_proj
         self.o_proj = self._original_mod.out_proj
-
-
-class ExaoneFlashAttention(DecoderOnlyFlashAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.out_proj
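
Because `attn_impl` now lives on `DecoderOnlyAttention` itself, one attention subclass per model is enough: `__post_init__` remaps the projection modules once and the same class serves both eager and flash attention. A sketch for a hypothetical model whose projections use different attribute names (the `query`/`key`/`value`/`dense` names are assumptions, not from this diff):

    # Hypothetical model-specific attention; works for attn_impl="eager" and "flash_attn".
    class HypotheticalModelAttention(DecoderOnlyAttention):
        def __post_init__(self):
            self.q_proj = self._original_mod.query
            self.k_proj = self._original_mod.key
            self.v_proj = self._original_mod.value
            self.o_proj = self._original_mod.dense
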