optimum-rbln 0.8.1rc0__py3-none-any.whl → 0.8.2a0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
Files changed (22)
  1. optimum/rbln/__version__.py +2 -2
  2. optimum/rbln/diffusers/configurations/models/__init__.py +1 -1
  3. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +5 -3
  4. optimum/rbln/diffusers/configurations/models/{configuration_cosmos_transformer.py → configuration_transformer_cosmos.py} +4 -2
  5. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +10 -6
  6. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +5 -1
  7. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  8. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  9. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  10. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  11. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +212 -257
  12. optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
  13. optimum/rbln/transformers/models/gemma/gemma_architecture.py +2 -40
  14. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
  15. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +18 -22
  16. optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
  17. optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
  18. optimum/rbln/transformers/models/phi/phi_architecture.py +14 -20
  19. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2a0.dist-info}/METADATA +1 -1
  20. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2a0.dist-info}/RECORD +22 -22
  21. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2a0.dist-info}/WHEEL +0 -0
  22. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2a0.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -149,6 +149,8 @@ class DecoderOnlyWrapper(nn.Module):
             This is only relevant if `attn_impl` is set to "flash_attn`
     """
 
+    _use_learned_pos_emb = False
+
     def __init__(
         self,
         causal_lm: PreTrainedModel,
@@ -159,7 +161,6 @@ class DecoderOnlyWrapper(nn.Module):
         use_inputs_embeds: bool,
         use_attention_mask: bool,
         use_position_ids: bool,
-        use_learned_pos_emb: Optional[bool] = None,
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
         sliding_window: Optional[int] = None,
@@ -182,7 +183,6 @@ class DecoderOnlyWrapper(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
         self.use_inputs_embeds = use_inputs_embeds
-        self.use_learned_pos_emb = use_learned_pos_emb
         self.sliding_window_layers = sliding_window_layers
         self.cache_impl = cache_impl
         self.sliding_window = sliding_window
@@ -207,51 +207,54 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
 
+    def get_decoder_layers(self, causal_lm: PreTrainedModel):
+        return causal_lm.model.layers
+
+    def get_attn_layer(self, layer: nn.Module):
+        return layer.self_attn
+
+    def get_model_layer(self, causal_lm: PreTrainedModel):
+        return causal_lm.model
+
+    def get_rbln_attn_class(self):
+        return DecoderOnlyAttention
+
+    def get_rbln_layer_class(self):
+        return DecoderOnlyLayer
+
+    def get_rbln_model_class(self):
+        return DecoderOnlyModel
+
+    def get_rbln_causal_lm_class(self):
+        return DecoderOnlyForCausalLM
+
     def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
         new_layers = []
-        for layer_idx, layer in enumerate(causal_lm.model.layers):
-            if layer_idx in self.sliding_window_layers:
-                # Flash attention is not yet supported for sliding window attention.
-                new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn,
-                    self.use_attention_mask,
-                    self.use_position_ids,
-                    kvcache_block_size=self.sliding_window,
-                    is_sliding=True,
-                )
-            else:
-                if self.attn_impl == "eager":
-                    new_self_attn = DecoderOnlyAttention(
-                        layer.self_attn,
-                        self.use_attention_mask,
-                        self.use_position_ids,
-                        kvcache_block_size=self.kvcache_block_size,
-                        is_sliding=False,
-                    )
-                elif self.attn_impl == "flash_attn":
-                    new_self_attn = DecoderOnlyFlashAttention(
-                        layer.self_attn,
-                        kvcache_partition_len=self.kvcache_partition_len,
-                        kvcache_block_size=self.kvcache_block_size,
-                        use_attention_mask=self.use_attention_mask,
-                        use_position_ids=self.use_position_ids,
-                    )
-                else:
-                    raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-
-            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+        for layer_idx, layer in enumerate(self.get_decoder_layers(causal_lm)):
+            new_self_attn = self.get_rbln_attn_class()(
+                self.get_attn_layer(layer),
+                self.use_attention_mask,
+                self.use_position_ids,
+                kvcache_block_size=self.sliding_window
+                if layer_idx in self.sliding_window_layers
+                else self.kvcache_block_size,
+                is_sliding=layer_idx in self.sliding_window_layers,
+                attn_impl=self.attn_impl,
+                kvcache_partition_len=self.kvcache_partition_len,
+            )
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
             new_layers.append(new_layer)
 
-        new_model = DecoderOnlyModel(
-            causal_lm.model,
+        new_model = self.get_rbln_model_class()(
+            self.get_model_layer(causal_lm),
             new_layers,
             partition_len=self.kvcache_partition_len,
             max_seq_len=max_seq_len,
             kvcache_block_size=self.kvcache_block_size,
-            use_learned_pos_emb=self.use_learned_pos_emb,
+            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
             sliding_window_layers=self.sliding_window_layers,
         )
-        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        new_causal_lm = self.get_rbln_causal_lm_class()(causal_lm, new_model)
         return new_causal_lm
 
     @property
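
Editor's note: the hunk above replaces the attn_impl/sliding-window branching inside convert_to_rbln_causal_lm with overridable hook methods (get_decoder_layers, get_attn_layer, get_rbln_attn_class, and so on), so model-specific wrappers only customize the accessors while the conversion loop lives once in the base class. Below is a minimal, self-contained sketch of that template-method pattern; the Toy* names are hypothetical stand-ins, not the actual optimum-rbln API.

class BaseWrapper:
    # Hook methods with generic defaults; subclasses override only what differs.
    def get_decoder_layers(self, model):
        return model.layers

    def get_attn_layer(self, layer):
        return layer["self_attn"]

    def convert(self, model):
        # Shared conversion loop, analogous to convert_to_rbln_causal_lm.
        return [self.get_attn_layer(layer) for layer in self.get_decoder_layers(model)]


class ToyModel:
    # Hypothetical stand-in for a decoder-only model with the default layout.
    layers = [{"self_attn": "attn_0"}, {"self_attn": "attn_1"}]


class ToyExaoneModel:
    # A model family whose layers live under a different attribute layout.
    h = [{"attention": "attn_0"}]


class ToyExaoneWrapper(BaseWrapper):
    # Only the accessors change; the conversion loop is inherited unchanged.
    def get_decoder_layers(self, model):
        return model.h

    def get_attn_layer(self, layer):
        return layer["attention"]


print(BaseWrapper().convert(ToyModel()))             # ['attn_0', 'attn_1']
print(ToyExaoneWrapper().convert(ToyExaoneModel()))  # ['attn_0']

This is why the Exaone wrapper further down in this diff can drop its own convert_to_rbln_causal_lm entirely and keep only the hook overrides.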
@@ -679,9 +682,23 @@ class DecoderOnlyAttention(nn.Module):
 
     Args:
         self_attn: Original attention module from the base model
+        use_attention_mask: Whether to use attention mask
+        use_position_ids: Whether to use position ids
+        kvcache_block_size: Block size for KV cache
+        is_sliding: Whether this is sliding window attention
+        attn_impl: Attention implementation type ("eager" or "flash_attn")
     """
 
-    def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size, is_sliding=False):
+    def __init__(
+        self,
+        self_attn,
+        use_attention_mask,
+        use_position_ids,
+        kvcache_block_size,
+        is_sliding=False,
+        attn_impl="eager",
+        kvcache_partition_len=None,
+    ):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -702,10 +719,28 @@ class DecoderOnlyAttention(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
         self.is_sliding = is_sliding
-        self.attention = self.get_attention()
+        self.attn_impl = attn_impl
+
+        if self.is_sliding and self.attn_impl != "eager":
+            raise NotImplementedError("Sliding window attention is only supported with eager attention.")
+
+        self.kvcache_partition_len = kvcache_partition_len
+
+        setattr(self, self.get_attention_name(), self.create_attention_op())
         self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
 
+    def get_attention_name(self):
+        if self.is_sliding:
+            return "sliding_window_attention"
+        elif self.attn_impl == "flash_attn":
+            return "flash_attention"
+        else:
+            return "attention"
+
+    def get_attention_op(self):
+        return getattr(self, self.get_attention_name())
+
     @property
     def phase(self):
         return self._phase
@@ -713,17 +748,36 @@ class DecoderOnlyAttention(nn.Module):
     @phase.setter
     def phase(self, phase: str):
         self._phase = phase
-        self.attention.phase = phase
+        getattr(self, self.get_attention_name()).phase = phase
 
-    def get_attention(self):
+    def create_attention_op(self):
         if self.is_sliding:
             return SlidingWindowAttentionOp(
-                self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.use_attention_mask,
+                self.use_position_ids,
             )
-        else:
+        elif self.attn_impl == "flash_attn":
+            return FlashAttentionOp(
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.kvcache_partition_len,
+                self.use_attention_mask,
+                self.use_position_ids,
+            )
+        elif self.attn_impl == "eager":
             return AttentionOp(
-                self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.use_attention_mask,
+                self.use_position_ids,
             )
+        else:
+            raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
 
     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
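
Editor's note: with the two hunks above, DecoderOnlyAttention now covers the eager, flash, and sliding-window cases itself; the attention op is chosen once in __init__ via create_attention_op() and stored under a per-implementation attribute name from get_attention_name(). A small self-contained sketch of that dispatch, with toy Op classes standing in for AttentionOp, FlashAttentionOp, and SlidingWindowAttentionOp:

class EagerOp: ...
class FlashOp: ...
class SlidingOp: ...


class Attn:
    def __init__(self, attn_impl="eager", is_sliding=False):
        if is_sliding and attn_impl != "eager":
            raise NotImplementedError("sliding window requires eager attention")
        self.attn_impl, self.is_sliding = attn_impl, is_sliding
        # Pick the op once and store it under an implementation-specific name.
        setattr(self, self.get_attention_name(), self.create_attention_op())

    def get_attention_name(self):
        if self.is_sliding:
            return "sliding_window_attention"
        return "flash_attention" if self.attn_impl == "flash_attn" else "attention"

    def create_attention_op(self):
        if self.is_sliding:
            return SlidingOp()
        return FlashOp() if self.attn_impl == "flash_attn" else EagerOp()

    def get_attention_op(self):
        return getattr(self, self.get_attention_name())


print(type(Attn("flash_attn").get_attention_op()).__name__)    # FlashOp
print(type(Attn(is_sliding=True).get_attention_op()).__name__)  # SlidingOp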
@@ -780,7 +834,7 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and "prefill" in self.phase:
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
 
-        attn_output = self.attention(
+        attn_output = self.get_attention_op()(
             query_states,
             key_states,
             value_states,
@@ -797,6 +851,14 @@ class DecoderOnlyAttention(nn.Module):
         return attn_outputs
 
 
+class DecoderOnlyFlashAttention(DecoderOnlyAttention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        logger.warning(
+            "DecoderOnlyFlashAttention is deprecated and may not work as expected. Use DecoderOnlyAttention instead."
+        )
+
+
 class AttentionOp(nn.Module):
     def __init__(
         self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
@@ -809,6 +871,17 @@ class AttentionOp(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
 
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if self.use_attention_mask:
+            attn_op_name = "paged_attn_"
+        else:
+            attn_op_name = "paged_causal_attn_"
+
+        attn_op_name += phase
+
+        return attn_op_name
+
     def forward(
         self,
         query_state: torch.Tensor,
@@ -857,63 +930,31 @@ class AttentionOp(nn.Module):
             self.head_dim,
         )
 
-        if self.phase == "decode":
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
-        else:
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                    is_bidirectional=True if self.phase == "image_prefill" else False,  # FIXME, Hard-coded for Gemma3.
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "seq": seq_position,
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if self.use_attention_mask != self.use_position_ids:
+            op_args["mask"] = attn_mask
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
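
Editor's note: the rewritten AttentionOp.forward above builds a single op_args dict and resolves the custom operator by name instead of branching over four near-identical call sites. The following self-contained sketch shows the name-based dispatch; the _FAKE_OPS dict is a hypothetical stand-in for torch.ops.rbln_custom_ops, while the name construction mirrors get_attn_op_name() from the diff.

def get_attn_op_name(phase: str, use_attention_mask: bool) -> str:
    # Operator name = mask-dependent prefix + phase suffix, as in the diff.
    suffix = "decode" if phase == "decode" else "prefill"
    prefix = "paged_attn_" if use_attention_mask else "paged_causal_attn_"
    return prefix + suffix


# Hypothetical registry; the real code looks names up on torch.ops.rbln_custom_ops.
_FAKE_OPS = {
    "paged_attn_decode": lambda **kw: "ran paged_attn_decode",
    "paged_causal_attn_prefill": lambda **kw: "ran paged_causal_attn_prefill",
}


def run(phase, use_attention_mask, **op_args):
    name = get_attn_op_name(phase, use_attention_mask)
    op = _FAKE_OPS.get(name)
    if op is None:
        raise ValueError(f"Attention operator {name} not found.")
    return op(**op_args)


print(run("decode", True))    # ran paged_attn_decode
print(run("prefill", False))  # ran paged_causal_attn_prefill

The flash and sliding-window subclasses in the hunks below reuse this same lookup and only change the prefix ("paged_flash_..." / "paged_sliding_window_attn_..."), which is what lets them share the forward structure.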
@@ -1012,70 +1053,6 @@ class RotaryEmbedding(nn.Module):
         )
 
 
-class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask, use_position_ids):
-        self.kvcache_partition_size = kvcache_partition_len
-        super().__init__(
-            self_attn=self_attn,
-            use_attention_mask=use_attention_mask,
-            use_position_ids=use_position_ids,
-            kvcache_block_size=kvcache_block_size,
-        )
-
-    def get_attention(self):
-        return FlashAttentionOp(
-            self.num_heads,
-            self.head_dim,
-            self.num_key_value_heads,
-            self.kvcache_partition_size,
-            self.use_attention_mask,
-            self.use_position_ids,
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        seq_positions: torch.LongTensor,
-        past_key_values: Tuple[Tuple[torch.Tensor]],
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        block_tables: Optional[torch.Tensor] = None,
-    ):
-        batch_size, query_length, _ = hidden_states.size()
-
-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
-
-        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
-            1, 2
-        )
-
-        if hasattr(self, "q_norm") and hasattr(self, "k_norm"):
-            query_states = self.q_norm(query_states)
-            key_states = self.k_norm(key_states)
-
-        if cos is not None and sin is not None:
-            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
-
-        attn_output = self.attention(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            past_key_state=past_key_values[self.layer_idx][0],
-            past_value_state=past_key_values[self.layer_idx][1],
-            seq_position=seq_positions,
-            scale=self.scale,
-            block_tables=block_tables,
-            kvcache_block_size=self.kvcache_block_size,
-        )
-
-        attn_outputs = self.o_proj(attn_output)
-        return attn_outputs
-
-
 class FlashAttentionOp(AttentionOp):
     def __init__(
         self,
@@ -1095,6 +1072,17 @@ class FlashAttentionOp(AttentionOp):
         )
         self.kvcache_partition_size = kvcache_partition_len
 
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if self.use_attention_mask:
+            attn_op_name = "paged_flash_attn_"
+        else:
+            attn_op_name = "paged_flash_causal_attn_"
+
+        attn_op_name += phase
+
+        return attn_op_name
+
     def forward(
         self,
         query_state,
@@ -1106,7 +1094,7 @@ class FlashAttentionOp(AttentionOp):
         seq_position,
         scale,
         block_tables,
-        kvcache_block_size,
+        block_size,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -1127,67 +1115,32 @@ class FlashAttentionOp(AttentionOp):
             self.head_dim,
         )
 
-        if self.phase == "decode":
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-        else:
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                    is_bidirectional=True if self.phase == "image_prefill" else False,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
-        # reshape for removing repeat_kv
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "seq": seq_position,
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+            "partition": self.kvcache_partition_size,
+        }
+
+        if self.use_attention_mask != self.use_position_ids:
+            op_args["mask"] = attn_mask
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
@@ -1196,6 +1149,14 @@ class FlashAttentionOp(AttentionOp):
 
 
 class SlidingWindowAttentionOp(AttentionOp):
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if self.use_attention_mask:
+            raise NotImplementedError("Attention mask is not supported for sliding window attention.")
+
+        attn_op_name = "paged_sliding_window_attn_" + phase
+        return attn_op_name
+
     def forward(
         self,
         query_state: torch.Tensor,
@@ -1226,35 +1187,29 @@ class SlidingWindowAttentionOp(AttentionOp):
             self.head_dim,
         )
 
-        if self.phase == "decode":
-            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_decode(
-                q=query_state,
-                k=key_state,
-                v=value_state,
-                kcache=past_key_state.unsqueeze(2),
-                vcache=past_value_state.unsqueeze(2),
-                cache_seq_len=seq_position[0],
-                cache_offset=seq_position[1],
-                scale=scale,
-                block_table=block_tables,
-                block_size=block_size,
-            )
-        else:
-            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_prefill(
-                q=query_state,
-                k=key_state,
-                v=value_state,
-                kcache=past_key_state.unsqueeze(2),
-                vcache=past_value_state.unsqueeze(2),
-                cache_seq_len=seq_position[0],
-                cache_offset=seq_position[1],
-                scale=scale,
-                block_table=block_tables,
-                block_size=block_size,
-                is_bidirectional=True if self.phase == "image_prefill" else False,
-            )
-
-        # reshape for removing repeat_kv
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "cache_seq_len": seq_position[0],
+            "cache_offset": seq_position[1],
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

optimum/rbln/transformers/models/exaone/exaone_architecture.py
@@ -19,8 +19,6 @@ import torch.nn as nn
 from ....utils import logging
 from ...models.decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-    DecoderOnlyFlashAttention,
-    DecoderOnlyForCausalLM,
     DecoderOnlyLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
@@ -36,38 +34,23 @@ logger = logging.get_logger(__name__)
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""
 
-    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM", max_seq_len: int):
-        new_layers = []
-        for layer in causal_lm.transformer.h:
-            if self.attn_impl == "eager":
-                new_self_attn = ExaoneAttention(
-                    layer.attn.attention,
-                    self.use_attention_mask,
-                    kvcache_block_size=self.kvcache_block_size,
-                    use_position_ids=self.use_position_ids,
-                )
-            elif self.attn_impl == "flash_attn":
-                new_self_attn = ExaoneFlashAttention(
-                    layer.attn.attention,
-                    kvcache_partition_len=self.kvcache_partition_len,
-                    use_attention_mask=self.use_attention_mask,
-                    kvcache_block_size=self.kvcache_block_size,
-                    use_position_ids=self.use_position_ids,
-                )
-            else:
-                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-
-            new_layer = ExaoneLayer(layer, new_self_attn)
-            new_layers.append(new_layer)
-        new_model = ExaoneModel(
-            causal_lm.transformer,
-            new_layers,
-            partition_len=self.kvcache_partition_len,
-            max_seq_len=max_seq_len,
-            sliding_window_layers=self.sliding_window_layers,
-        )
-        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
-        return new_causal_lm
+    def get_decoder_layers(self, causal_lm: "ExaoneForCausalLM"):
+        return causal_lm.transformer.h
+
+    def get_attn_layer(self, layer: nn.Module):
+        return layer.attn.attention
+
+    def get_model_layer(self, causal_lm: "ExaoneForCausalLM"):
+        return causal_lm.transformer
+
+    def get_rbln_attn_class(self):
+        return ExaoneAttention
+
+    def get_rbln_layer_class(self):
+        return ExaoneLayer
+
+    def get_rbln_model_class(self):
+        return ExaoneModel
 
 
 class ExaoneModel(DecoderOnlyModel):
@@ -92,11 +75,3 @@ class ExaoneAttention(DecoderOnlyAttention):
         self.k_proj = self._original_mod.k_proj
         self.v_proj = self._original_mod.v_proj
         self.o_proj = self._original_mod.out_proj
-
-
-class ExaoneFlashAttention(DecoderOnlyFlashAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.out_proj