optimum-rbln 0.7.5a1__py3-none-any.whl → 0.7.5rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. optimum/rbln/__init__.py +10 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/transformers/__init__.py +10 -0
  4. optimum/rbln/transformers/models/__init__.py +14 -0
  5. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  6. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  7. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +114 -19
  8. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +29 -10
  9. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  10. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  11. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  12. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +69 -0
  13. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +446 -0
  14. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1057 -0
  15. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  16. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  17. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -0
  18. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -1
  19. optimum/rbln/transformers/models/phi/phi_architecture.py +4 -1
  20. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -2
  21. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc0.dist-info}/METADATA +1 -1
  22. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc0.dist-info}/RECORD +24 -20
  23. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc0.dist-info}/WHEEL +0 -0
  24. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -38,6 +38,7 @@ _import_structure = {
         "RBLNAutoModelForCTC",
         "RBLNAutoModelForDepthEstimation",
         "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForImageTextToText",
         "RBLNAutoModelForMaskedLM",
         "RBLNAutoModelForQuestionAnswering",
         "RBLNAutoModelForSeq2SeqLM",
@@ -78,6 +79,10 @@ _import_structure = {
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaForCausalLM",
         "RBLNGemmaForCausalLMConfig",
+        "RBLNGemma3ForCausalLM",
+        "RBLNGemma3ForCausalLMConfig",
+        "RBLNGemma3ForConditionalGeneration",
+        "RBLNGemma3ForConditionalGenerationConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
         "RBLNIdefics3VisionTransformer",
@@ -259,6 +264,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForCTC,
         RBLNAutoModelForDepthEstimation,
         RBLNAutoModelForImageClassification,
+        RBLNAutoModelForImageTextToText,
         RBLNAutoModelForMaskedLM,
         RBLNAutoModelForQuestionAnswering,
         RBLNAutoModelForSeq2SeqLM,
@@ -297,6 +303,10 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma3ForCausalLM,
+        RBLNGemma3ForCausalLMConfig,
+        RBLNGemma3ForConditionalGeneration,
+        RBLNGemma3ForConditionalGenerationConfig,
         RBLNGemmaForCausalLM,
         RBLNGemmaForCausalLMConfig,
         RBLNGPT2LMHeadModel,
optimum/rbln/__version__.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.5a1'
-__version_tuple__ = version_tuple = (0, 7, 5, 'a1')
+__version__ = version = '0.7.5rc0'
+__version_tuple__ = version_tuple = (0, 7, 5, 'rc0')
optimum/rbln/transformers/__init__.py CHANGED
@@ -34,6 +34,7 @@ _import_structure = {
         "RBLNAutoModelForCTC",
         "RBLNAutoModelForDepthEstimation",
         "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForImageTextToText",
         "RBLNAutoModelForMaskedLM",
         "RBLNAutoModelForQuestionAnswering",
         "RBLNAutoModelForSeq2SeqLM",
@@ -72,6 +73,10 @@ _import_structure = {
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaForCausalLM",
         "RBLNGemmaForCausalLMConfig",
+        "RBLNGemma3ForCausalLM",
+        "RBLNGemma3ForCausalLMConfig",
+        "RBLNGemma3ForConditionalGeneration",
+        "RBLNGemma3ForConditionalGenerationConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
         "RBLNIdefics3VisionTransformer",
@@ -148,6 +153,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForCTC,
         RBLNAutoModelForDepthEstimation,
         RBLNAutoModelForImageClassification,
+        RBLNAutoModelForImageTextToText,
         RBLNAutoModelForMaskedLM,
         RBLNAutoModelForQuestionAnswering,
         RBLNAutoModelForSeq2SeqLM,
@@ -184,6 +190,10 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma3ForCausalLM,
+        RBLNGemma3ForCausalLMConfig,
+        RBLNGemma3ForConditionalGeneration,
+        RBLNGemma3ForConditionalGenerationConfig,
         RBLNGemmaForCausalLM,
         RBLNGemmaForCausalLMConfig,
         RBLNGPT2LMHeadModel,
optimum/rbln/transformers/models/__init__.py CHANGED
@@ -31,6 +31,7 @@ _import_structure = {
         "RBLNAutoModelForSequenceClassification",
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
+        "RBLNAutoModelForImageTextToText",
     ],
     "bart": [
         "RBLNBartForConditionalGeneration",
@@ -80,6 +81,12 @@ _import_structure = {
     ],
     "exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
     "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig"],
+    "gemma3": [
+        "RBLNGemma3ForCausalLM",
+        "RBLNGemma3ForCausalLMConfig",
+        "RBLNGemma3ForConditionalGeneration",
+        "RBLNGemma3ForConditionalGenerationConfig",
+    ],
     "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig"],
     "idefics3": [
         "RBLNIdefics3VisionTransformer",
@@ -121,6 +128,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForCTC,
         RBLNAutoModelForDepthEstimation,
         RBLNAutoModelForImageClassification,
+        RBLNAutoModelForImageTextToText,
         RBLNAutoModelForMaskedLM,
         RBLNAutoModelForQuestionAnswering,
         RBLNAutoModelForSeq2SeqLM,
@@ -170,6 +178,12 @@ if TYPE_CHECKING:
     )
     from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
     from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig
+    from .gemma3 import (
+        RBLNGemma3ForCausalLM,
+        RBLNGemma3ForCausalLMConfig,
+        RBLNGemma3ForConditionalGeneration,
+        RBLNGemma3ForConditionalGenerationConfig,
+    )
     from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig
     from .idefics3 import (
         RBLNIdefics3ForConditionalGeneration,
optimum/rbln/transformers/models/auto/__init__.py CHANGED
@@ -19,6 +19,7 @@ from .modeling_auto import (
     RBLNAutoModelForCTC,
     RBLNAutoModelForDepthEstimation,
     RBLNAutoModelForImageClassification,
+    RBLNAutoModelForImageTextToText,
     RBLNAutoModelForMaskedLM,
     RBLNAutoModelForQuestionAnswering,
     RBLNAutoModelForSeq2SeqLM,
optimum/rbln/transformers/models/auto/modeling_auto.py CHANGED
@@ -23,6 +23,8 @@ from transformers.models.auto.modeling_auto import (
     MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING,
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
     MODEL_FOR_MASKED_LM_MAPPING,
     MODEL_FOR_MASKED_LM_MAPPING_NAMES,
     MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -90,6 +92,11 @@ class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
     _model_mapping_names = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES


+class RBLNAutoModelForImageTextToText(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
+    _model_mapping_names = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
+
+
 class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
     _model_mapping_names = MODEL_FOR_MASKED_LM_MAPPING_NAMES
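Note: the new auto class follows the same two-attribute pattern as the other RBLN auto classes, resolving the concrete RBLN class through the corresponding transformers mapping. A rough, simplified sketch of that dispatch idea (the real _BaseAutoModelClass in optimum-rbln does considerably more; the helper below is purely illustrative):

    from transformers import AutoConfig
    from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES

    def resolve_rbln_class_name(model_id: str) -> str:
        # e.g. model_type "gemma3" -> "Gemma3ForConditionalGeneration" -> "RBLNGemma3ForConditionalGeneration"
        config = AutoConfig.from_pretrained(model_id)
        return "RBLN" + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES[config.model_type]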
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py CHANGED
@@ -157,7 +157,11 @@ class DecoderOnlyWrapper(nn.Module):
         self.config = causal_lm.config

         if use_rotary_emb:
-            self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
+            rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+            if isinstance(rotary_embs, tuple):
+                self.rotary_emb_global, self.rotary_emb_local = rotary_embs
+            else:
+                self.rotary_emb = rotary_embs
         else:
             self.rotary_emb = None
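Note: the tuple branch above lets a wrapper return two rotary-embedding tables instead of one. That is what a Gemma3-style model needs, since its full-attention layers use a long-context RoPE base frequency while its sliding-window layers keep a short-range base. A self-contained sketch of the two tables (head_dim, sequence length, and theta values are illustrative, not taken from this diff):

    import torch

    def build_rope_tables(head_dim: int, max_seq_len: int, theta: float):
        # Standard RoPE cache: cos/sin tables of shape [max_seq_len, head_dim].
        inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

    rotary_emb_global = build_rope_tables(head_dim=256, max_seq_len=8192, theta=1_000_000.0)
    rotary_emb_local = build_rope_tables(head_dim=256, max_seq_len=8192, theta=10_000.0)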
@@ -195,7 +199,10 @@ class DecoderOnlyWrapper(nn.Module):
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
                 new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                    layer.self_attn,
+                    self.use_attention_mask,
+                    self.use_position_ids,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
@@ -203,6 +210,7 @@ class DecoderOnlyWrapper(nn.Module):
                     kvcache_partition_len=self.kvcache_partition_len,
                     kvcache_block_size=self.kvcache_block_size,
                     use_attention_mask=self.use_attention_mask,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -363,6 +371,13 @@ class DecoderOnlyForCausalLM(nn.Module):
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

         logits = self.lm_head(hidden_states)
+
+        # Apply final logit softcapping if configured, e.g. for Gemma2
+        if getattr(self.config, "final_logit_softcapping", None) is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
         return logits

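Note: the added lines implement logit soft-capping, which squashes logits smoothly into (-cap, cap) with a tanh instead of hard clipping. An equivalent one-liner for reference:

    import torch

    def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
        # tanh bounds values to (-1, 1); scaling by `cap` rescales to (-cap, cap) while staying differentiable.
        return cap * torch.tanh(logits / cap)

    capped = soft_cap(torch.randn(2, 8, 256_000), cap=30.0)  # Gemma2 configs ship final_logit_softcapping = 30.0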
@@ -610,7 +625,7 @@ class DecoderOnlyAttention(nn.Module):
         self_attn: Original attention module from the base model
     """

-    def __init__(self, self_attn, use_attention_mask, kvcache_block_size):
+    def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -629,6 +644,7 @@ class DecoderOnlyAttention(nn.Module):
             self.num_key_value_heads = self.num_heads

         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
         self.attention = self.get_attention()
         self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
@@ -643,7 +659,9 @@ class DecoderOnlyAttention(nn.Module):
         self.attention.phase = phase

     def get_attention(self):
-        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask)
+        return AttentionOp(
+            self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+        )

     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -716,13 +734,16 @@ class DecoderOnlyAttention(nn.Module):


 class AttentionOp(nn.Module):
-    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool):
+    def __init__(
+        self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
+    ):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids

     def forward(
         self,
@@ -755,7 +776,8 @@ class AttentionOp(nn.Module):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
-        if self.use_attention_mask:
+
+        if self.use_attention_mask and not self.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)

         if self.phase == "decode":
@@ -772,7 +794,7 @@ class AttentionOp(nn.Module):
         )

         if self.phase == "decode":
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
                     q=query_state,
                     k=key_state,
@@ -796,11 +818,11 @@ class AttentionOp(nn.Module):
                     scale=scale,
                     block_table=block_tables,
                     block_size=block_size,
-                    mask=None,
+                    mask=attn_mask if self.use_position_ids else None,
                 )

         else:
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
                     q=query_state,
                     k=key_state,
@@ -824,8 +846,8 @@ class AttentionOp(nn.Module):
                     scale=scale,
                     block_table=block_tables,
                     block_size=block_size,
-                    is_bidirectional=False,
-                    mask=None,
+                    is_bidirectional=True if self.phase == "image_prefill" else False,  # FIXME, Hard-coded for Gemma3.
+                    mask=attn_mask if self.use_position_ids else None,
                 )

         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
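Note: the new image_prefill phase allows image tokens to attend to one another bidirectionally during prefill while text tokens stay causal, which is how Gemma3-style multimodal prompts are encoded. A minimal, framework-agnostic sketch of such a mask (the convention that token_type_ids == 1 marks image tokens is an assumption for illustration):

    import torch

    def build_prefill_mask(token_type_ids: torch.Tensor) -> torch.Tensor:
        # token_type_ids: [seq_len]; returned mask is True where attention is allowed.
        seq_len = token_type_ids.shape[0]
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        image = token_type_ids.bool()
        # All image tokens may attend to each other (a simplification; a real mask
        # would scope the bidirectional block to tokens of the same image).
        bidirectional = image.unsqueeze(0) & image.unsqueeze(1)
        return causal | bidirectional

    mask = build_prefill_mask(torch.tensor([0, 0, 1, 1, 1, 0]))  # text, text, image x3, text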
@@ -927,10 +949,13 @@ class RotaryEmbedding(nn.Module):


 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask):
+    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask, use_position_ids):
         self.kvcache_partition_size = kvcache_partition_len
         super().__init__(
-            self_attn=self_attn, use_attention_mask=use_attention_mask, kvcache_block_size=kvcache_block_size
+            self_attn=self_attn,
+            use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
+            kvcache_block_size=kvcache_block_size,
         )

     def get_attention(self):
@@ -940,6 +965,7 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
             self.num_key_value_heads,
             self.kvcache_partition_size,
             self.use_attention_mask,
+            self.use_position_ids,
         )

     def forward(
@@ -991,12 +1017,14 @@ class FlashAttentionOp(AttentionOp):
         num_key_value_heads: int,
         kvcache_partition_len: int,
         use_attention_mask: bool,
+        use_position_ids: bool,
     ):
         super().__init__(
             num_heads=num_heads,
             head_dim=head_dim,
             num_key_value_heads=num_key_value_heads,
             use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
         )
         self.kvcache_partition_size = kvcache_partition_len

@@ -1016,7 +1044,7 @@ class FlashAttentionOp(AttentionOp):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-        if self.use_attention_mask:
+        if self.use_attention_mask and not self.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)

         if self.phase == "decode":
@@ -1033,7 +1061,7 @@ class FlashAttentionOp(AttentionOp):
         )

         if self.phase == "decode":
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
                     q=query_state,
                     k=key_state,
@@ -1059,10 +1087,10 @@ class FlashAttentionOp(AttentionOp):
                     block_table=block_tables,
                     block_size=kvcache_block_size,
                     partition=self.kvcache_partition_size,
-                    mask=None,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
         else:
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
                     q=query_state,
                     k=key_state,
@@ -1088,8 +1116,8 @@ class FlashAttentionOp(AttentionOp):
                     block_table=block_tables,
                     block_size=kvcache_block_size,
                     partition=self.kvcache_partition_size,
-                    is_bidirectional=False,
-                    mask=None,
+                    is_bidirectional=True if self.phase == "image_prefill" else False,
+                    mask=attn_mask if self.use_position_ids else None,
                 )

         # reshape for removing repeat_kv
@@ -1098,3 +1126,70 @@ class FlashAttentionOp(AttentionOp):
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

         return attn_output
+
+
+class SlidingWindowAttentionOp(AttentionOp):
+    def forward(
+        self,
+        query_state: torch.Tensor,
+        key_state: torch.Tensor,
+        value_state: torch.Tensor,
+        attn_mask: torch.Tensor,
+        past_key_state: torch.Tensor,
+        past_value_state: torch.Tensor,
+        seq_position: Tuple[torch.Tensor],
+        scale: torch.Tensor,
+        block_tables: torch.Tensor,
+        block_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)
+        value_state = value_state.unsqueeze(2)
+
+        if self.phase == "decode":
+            batch_size = key_state.shape[0]
+        else:
+            batch_size = 1
+
+        query_state = query_state.view(
+            batch_size,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
+        )
+
+        if self.phase == "decode":
+            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_decode(
+                q=query_state,
+                k=key_state,
+                v=value_state,
+                kcache=past_key_state.unsqueeze(2),
+                vcache=past_value_state.unsqueeze(2),
+                cache_seq_len=seq_position[0],
+                cache_offset=seq_position[1],
+                scale=scale,
+                block_table=block_tables,
+                block_size=block_size,
+            )
+        else:
+            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_prefill(
+                q=query_state,
+                k=key_state,
+                v=value_state,
+                kcache=past_key_state.unsqueeze(2),
+                vcache=past_value_state.unsqueeze(2),
+                cache_seq_len=seq_position[0],
+                cache_offset=seq_position[1],
+                scale=scale,
+                block_table=block_tables,
+                block_size=block_size,
+                is_bidirectional=True if self.phase == "image_prefill" else False,
+            )
+
+        # reshape for removing repeat_kv
+        attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
+
+        return attn_output
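Note: for readers unfamiliar with the new op, sliding-window attention lets every query attend only to itself and the most recent W positions, which keeps the per-layer KV cache bounded by W rather than by the full context length. A small plain-PyTorch reference, independent of the RBLN custom ops used above:

    import torch
    import torch.nn.functional as F

    def sliding_window_attention(q, k, v, window: int):
        # q, k, v: [batch, heads, seq, head_dim]; causal attention limited to the last `window` keys.
        seq = q.shape[-2]
        pos = torch.arange(seq)
        keep = (pos.unsqueeze(1) >= pos.unsqueeze(0)) & (pos.unsqueeze(1) - pos.unsqueeze(0) < window)
        scores = (q @ k.transpose(-1, -2)) / q.shape[-1] ** 0.5
        scores = scores.masked_fill(~keep, float("-inf"))
        return F.softmax(scores, dim=-1) @ v

    out = sliding_window_attention(
        torch.randn(1, 4, 16, 8), torch.randn(1, 4, 16, 8), torch.randn(1, 4, 16, 8), window=4
    )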
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py CHANGED
@@ -167,6 +167,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -193,6 +195,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 attention_mask=attention_mask,
                 position_embed=position_embed,
                 position_ids=position_ids,
+                local_block_tables=local_block_tables,
             )
         else:
             return self.prefill_forward(
@@ -202,6 +205,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 batch_idx,
                 block_tables,
                 position_embed=position_embed,
+                token_type_ids=token_type_ids,
+                local_block_tables=local_block_tables,
             )

     def decode_forward(
@@ -213,6 +218,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -262,6 +268,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
     ):
         """
         Prepare inputs for prefill phase.
@@ -345,6 +353,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
         position_embed: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -360,7 +370,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             position_embed,
             padded_cache_lengths,
             query_length,
-        ) = self._prepare_prefill_inputs(inputs, cache_position, attention_mask, position_embed)
+        ) = self._prepare_prefill_inputs(
+            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+        )

         # Process input in chunks of size `prefill_chunk_size`
         for step in range(0, query_length, self.prefill_chunk_size):
@@ -373,7 +385,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             if position_embed is not None:
                 position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]

-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 # Update attention mask to ensure proper causal behavior
                 if step >= self.prefill_chunk_size:
                     chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
@@ -387,10 +399,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
+                position_embed_chunk if position_embed is not None else None,
                 query_position,
                 chunked_attention_mask if self.use_attention_mask else None,
-                position_ids_chunk if position_ids is not None else None,
-                position_embed_chunk if position_embed is not None else None,
+                position_ids_chunk if self.use_position_ids else None,
                 out=out_buffers,
             )

@@ -440,12 +452,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if self.rbln_config.use_inputs_embeds:
             main_input_name = "inputs_embeds"
             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-            with no_init_weights():
-                self.embed_tokens = torch.nn.Embedding(
-                    self.config.vocab_size,
-                    self.config.hidden_size,
-                    self.config.pad_token_id,
-                )
+            self.embed_tokens = self._create_embedding_layer()
             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
         else:
             self.embed_tokens = None
@@ -478,6 +485,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             attn_impl=self.rbln_config.attn_impl,
             use_position_ids=self.rbln_config.use_position_ids,
         )
+
         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNRuntimeModel(
@@ -515,6 +523,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                self.config.pad_token_id,
+            )
+        return embed_tokens
+
     def get_input_embeddings(self):
         return self.embed_tokens

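Note: factoring the embedding construction into _create_embedding_layer lets subclasses swap in a different embedding module without duplicating the artifact-loading logic. A hypothetical override pattern, self-contained and for illustration only (ScaledEmbedding is not part of optimum-rbln; Gemma-family models scale input embeddings by sqrt(hidden_size)):

    import torch
    from torch import nn

    class ScaledEmbedding(nn.Embedding):
        # Placeholder example of a model-specific embedding that a subclass might return
        # from its own _create_embedding_layer() instead of a plain nn.Embedding.
        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            return super().forward(input_ids) * (self.embedding_dim ** 0.5)

    embed = ScaledEmbedding(32000, 2048, padding_idx=0)
    print(embed(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 2048])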
@@ -1101,6 +1118,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
@@ -1123,6 +1141,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                 cache_position=cache_position,
                 batch_idx=b_idx,
+                token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
             )
             padded_cache_lengths[b_idx] += output.padded_cache_lengths
             logits.append(output.logits)
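Note: token_type_ids here is the usual processor output distinguishing image placeholder tokens from text tokens, and the b_idx : b_idx + 1 slice keeps the batch dimension when each sample is prefilled separately. A sketch of the shapes involved (the 0 = text / 1 = image convention is illustrative):

    import torch

    token_type_ids = torch.tensor([
        [0, 0, 1, 1, 1, 1, 0, 0],  # sample 0: four image tokens in the middle of the prompt
        [0, 0, 0, 0, 0, 0, 0, 0],  # sample 1: text only
    ])
    per_sample = token_type_ids[0:1]  # shape [1, 8]; batch dimension preserved for prefill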
optimum/rbln/transformers/models/exaone/exaone_architecture.py CHANGED
@@ -41,7 +41,10 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
         for layer in causal_lm.transformer.h:
             if self.attn_impl == "eager":
                 new_self_attn = ExaoneAttention(
-                    layer.attn.attention, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                    layer.attn.attention,
+                    self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = ExaoneFlashAttention(
@@ -49,6 +52,7 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
                     kvcache_partition_len=self.kvcache_partition_len,
                     use_attention_mask=self.use_attention_mask,
                     kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
optimum/rbln/transformers/models/gemma/gemma_architecture.py CHANGED
@@ -34,7 +34,10 @@ class GemmaWrapper(DecoderOnlyWrapper):
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
                 new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                    layer.self_attn,
+                    self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
@@ -42,6 +45,7 @@ class GemmaWrapper(DecoderOnlyWrapper):
                     kvcache_partition_len=self.kvcache_partition_len,
                     use_attention_mask=self.use_attention_mask,
                     kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
optimum/rbln/transformers/models/gemma3/__init__.py ADDED
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig, RBLNGemma3ForConditionalGenerationConfig
+from .modeling_gemma3 import RBLNGemma3ForCausalLM, RBLNGemma3ForConditionalGeneration
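Note: once these exports are in place, the new classes are reachable from the package root and follow optimum-rbln's usual from_pretrained/export flow. A minimal sketch (model id and rbln_* options are illustrative, not taken from this diff):

    from optimum.rbln import RBLNGemma3ForCausalLM

    model = RBLNGemma3ForCausalLM.from_pretrained(
        "google/gemma-3-1b-it",   # placeholder checkpoint id
        export=True,              # trace and compile the PyTorch model for RBLN NPUs
        rbln_batch_size=1,
        rbln_max_seq_len=8192,
    )
    model.save_pretrained("gemma3-1b-rbln")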