optimum-rbln 0.7.5a1__py3-none-any.whl → 0.7.5rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +10 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/transformers/__init__.py +10 -0
- optimum/rbln/transformers/models/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/__init__.py +1 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +114 -19
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +29 -10
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +69 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +446 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1057 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +4 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -2
- {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc0.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc0.dist-info}/RECORD +24 -20
- {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -38,6 +38,7 @@ _import_structure = {
         "RBLNAutoModelForCTC",
         "RBLNAutoModelForDepthEstimation",
         "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForImageTextToText",
         "RBLNAutoModelForMaskedLM",
         "RBLNAutoModelForQuestionAnswering",
         "RBLNAutoModelForSeq2SeqLM",
@@ -78,6 +79,10 @@ _import_structure = {
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaForCausalLM",
         "RBLNGemmaForCausalLMConfig",
+        "RBLNGemma3ForCausalLM",
+        "RBLNGemma3ForCausalLMConfig",
+        "RBLNGemma3ForConditionalGeneration",
+        "RBLNGemma3ForConditionalGenerationConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
         "RBLNIdefics3VisionTransformer",
@@ -259,6 +264,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForCTC,
         RBLNAutoModelForDepthEstimation,
         RBLNAutoModelForImageClassification,
+        RBLNAutoModelForImageTextToText,
         RBLNAutoModelForMaskedLM,
         RBLNAutoModelForQuestionAnswering,
         RBLNAutoModelForSeq2SeqLM,
@@ -297,6 +303,10 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma3ForCausalLM,
+        RBLNGemma3ForCausalLMConfig,
+        RBLNGemma3ForConditionalGeneration,
+        RBLNGemma3ForConditionalGenerationConfig,
         RBLNGemmaForCausalLM,
         RBLNGemmaForCausalLMConfig,
         RBLNGPT2LMHeadModel,
optimum/rbln/__version__.py
CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.5a1'
-__version_tuple__ = version_tuple = (0, 7, 5, 'a1')
+__version__ = version = '0.7.5rc0'
+__version_tuple__ = version_tuple = (0, 7, 5, 'rc0')
optimum/rbln/transformers/__init__.py
CHANGED
@@ -34,6 +34,7 @@ _import_structure = {
         "RBLNAutoModelForCTC",
         "RBLNAutoModelForDepthEstimation",
         "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForImageTextToText",
         "RBLNAutoModelForMaskedLM",
         "RBLNAutoModelForQuestionAnswering",
         "RBLNAutoModelForSeq2SeqLM",
@@ -72,6 +73,10 @@ _import_structure = {
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaForCausalLM",
         "RBLNGemmaForCausalLMConfig",
+        "RBLNGemma3ForCausalLM",
+        "RBLNGemma3ForCausalLMConfig",
+        "RBLNGemma3ForConditionalGeneration",
+        "RBLNGemma3ForConditionalGenerationConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
         "RBLNIdefics3VisionTransformer",
@@ -148,6 +153,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForCTC,
         RBLNAutoModelForDepthEstimation,
         RBLNAutoModelForImageClassification,
+        RBLNAutoModelForImageTextToText,
         RBLNAutoModelForMaskedLM,
         RBLNAutoModelForQuestionAnswering,
         RBLNAutoModelForSeq2SeqLM,
@@ -184,6 +190,10 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma3ForCausalLM,
+        RBLNGemma3ForCausalLMConfig,
+        RBLNGemma3ForConditionalGeneration,
+        RBLNGemma3ForConditionalGenerationConfig,
         RBLNGemmaForCausalLM,
         RBLNGemmaForCausalLMConfig,
         RBLNGPT2LMHeadModel,
optimum/rbln/transformers/models/__init__.py
CHANGED
@@ -31,6 +31,7 @@ _import_structure = {
         "RBLNAutoModelForSequenceClassification",
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
+        "RBLNAutoModelForImageTextToText",
     ],
     "bart": [
         "RBLNBartForConditionalGeneration",
@@ -80,6 +81,12 @@ _import_structure = {
     ],
     "exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
     "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig"],
+    "gemma3": [
+        "RBLNGemma3ForCausalLM",
+        "RBLNGemma3ForCausalLMConfig",
+        "RBLNGemma3ForConditionalGeneration",
+        "RBLNGemma3ForConditionalGenerationConfig",
+    ],
     "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig"],
     "idefics3": [
         "RBLNIdefics3VisionTransformer",
@@ -121,6 +128,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForCTC,
         RBLNAutoModelForDepthEstimation,
         RBLNAutoModelForImageClassification,
+        RBLNAutoModelForImageTextToText,
         RBLNAutoModelForMaskedLM,
         RBLNAutoModelForQuestionAnswering,
         RBLNAutoModelForSeq2SeqLM,
@@ -170,6 +178,12 @@ if TYPE_CHECKING:
     )
     from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
     from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig
+    from .gemma3 import (
+        RBLNGemma3ForCausalLM,
+        RBLNGemma3ForCausalLMConfig,
+        RBLNGemma3ForConditionalGeneration,
+        RBLNGemma3ForConditionalGenerationConfig,
+    )
     from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig
     from .idefics3 import (
         RBLNIdefics3ForConditionalGeneration,
optimum/rbln/transformers/models/auto/modeling_auto.py
CHANGED
@@ -23,6 +23,8 @@ from transformers.models.auto.modeling_auto import (
     MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING,
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
     MODEL_FOR_MASKED_LM_MAPPING,
     MODEL_FOR_MASKED_LM_MAPPING_NAMES,
     MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -90,6 +92,11 @@ class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
     _model_mapping_names = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 
 
+class RBLNAutoModelForImageTextToText(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
+    _model_mapping_names = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
+
+
 class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
     _model_mapping_names = MODEL_FOR_MASKED_LM_MAPPING_NAMES
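Note: the new `RBLNAutoModelForImageTextToText` auto class follows the same pattern as the other RBLN auto classes, pointing at transformers' image-text-to-text mapping. A hedged usage sketch follows; the model id and the `export=True` compile-on-load flag are illustrative assumptions in the style of the other optimum-rbln auto classes, not something this diff confirms.

```python
# Hedged usage sketch (assumes the standard optimum-rbln from_pretrained flow;
# the checkpoint name below is illustrative only).
from optimum.rbln import RBLNAutoModelForImageTextToText

model = RBLNAutoModelForImageTextToText.from_pretrained(
    "google/gemma-3-4b-it",  # assumption: any checkpoint covered by the image-text-to-text mapping
    export=True,             # assumption: compile to RBLN format on load, as with other RBLN auto classes
)
```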
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
CHANGED
@@ -157,7 +157,11 @@ class DecoderOnlyWrapper(nn.Module):
         self.config = causal_lm.config
 
         if use_rotary_emb:
-            self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
+            rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+            if isinstance(rotary_embs, tuple):
+                self.rotary_emb_global, self.rotary_emb_local = rotary_embs
+            else:
+                self.rotary_emb = rotary_embs
         else:
             self.rotary_emb = None
 
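Note: the rotary-embedding setup now tolerates `get_rotary_emb` returning a pair of tables. A plausible reading, sketched below with dummy stand-ins, is that Gemma3-style models keep separate RoPE parameters for global-attention layers and for local sliding-window layers; the helper names here are illustrative, not the library's API.

```python
# Hedged sketch of the tuple-unpacking pattern above, using dummy modules so it runs standalone.
import torch

class RotaryHolder:
    def assign(self, rotary_embs):
        if isinstance(rotary_embs, tuple):
            # assumed Gemma3 case: one RoPE table for global layers, one for sliding-window layers
            self.rotary_emb_global, self.rotary_emb_local = rotary_embs
        else:
            # single-RoPE models keep the existing attribute
            self.rotary_emb = rotary_embs

holder = RotaryHolder()
holder.assign((torch.nn.Identity(), torch.nn.Identity()))  # stand-ins for two RotaryEmbedding modules
```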
@@ -195,7 +199,10 @@ class DecoderOnlyWrapper(nn.Module):
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
                 new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn,
+                    layer.self_attn,
+                    self.use_attention_mask,
+                    self.use_position_ids,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
@@ -203,6 +210,7 @@ class DecoderOnlyWrapper(nn.Module):
                     kvcache_partition_len=self.kvcache_partition_len,
                     kvcache_block_size=self.kvcache_block_size,
                     use_attention_mask=self.use_attention_mask,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -363,6 +371,13 @@ class DecoderOnlyForCausalLM(nn.Module):
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
         logits = self.lm_head(hidden_states)
+
+        # Apply final logit softmaxing if configured, e.g. for Gemma2
+        if getattr(self.config, "final_logit_softcapping", None) is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
         return logits
 
 
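Note: the added block implements final logit soft-capping when the config defines `final_logit_softcapping` (as Gemma-family configs do): logits are squashed into the open interval (-cap, cap) while staying roughly linear near zero. A standalone sketch of the same transform:

```python
# Standalone sketch of the soft-capping applied above.
import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # divide, squash with tanh, rescale: output magnitude is bounded by cap
    return torch.tanh(logits / cap) * cap

x = torch.randn(2, 8, 256) * 100.0
assert softcap(x, cap=30.0).abs().max() < 30.0
```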
@@ -610,7 +625,7 @@ class DecoderOnlyAttention(nn.Module):
         self_attn: Original attention module from the base model
     """
 
-    def __init__(self, self_attn, use_attention_mask, kvcache_block_size):
+    def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -629,6 +644,7 @@ class DecoderOnlyAttention(nn.Module):
             self.num_key_value_heads = self.num_heads
 
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
         self.attention = self.get_attention()
         self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
@@ -643,7 +659,9 @@ class DecoderOnlyAttention(nn.Module):
         self.attention.phase = phase
 
     def get_attention(self):
-        return AttentionOp(
+        return AttentionOp(
+            self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+        )
 
     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -716,13 +734,16 @@ class DecoderOnlyAttention(nn.Module):
 
 
 class AttentionOp(nn.Module):
-    def __init__(
+    def __init__(
+        self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
+    ):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
 
     def forward(
         self,
@@ -755,7 +776,8 @@ class AttentionOp(nn.Module):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
-        if self.use_attention_mask:
+
+        if self.use_attention_mask and not self.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -772,7 +794,7 @@ class AttentionOp(nn.Module):
         )
 
         if self.phase == "decode":
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
                     q=query_state,
                     k=key_state,
@@ -796,11 +818,11 @@ class AttentionOp(nn.Module):
                     scale=scale,
                     block_table=block_tables,
                     block_size=block_size,
-                    mask=None,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
 
         else:
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
                     q=query_state,
                     k=key_state,
@@ -824,8 +846,8 @@ class AttentionOp(nn.Module):
                     scale=scale,
                     block_table=block_tables,
                     block_size=block_size,
-                    is_bidirectional=False,
-                    mask=None,
+                    is_bidirectional=True if self.phase == "image_prefill" else False,  # FIXME, Hard-coded for Gemma3.
+                    mask=attn_mask if self.use_position_ids else None,
                 )
 
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
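Note: taken together, these hunks route masking differently depending on `use_position_ids`: the pre-existing masked kernels are used only when an attention mask is present and position ids are not, otherwise the mask (possibly `None`) is forwarded directly to the unmasked paged kernels, and prefill becomes bidirectional for the `image_prefill` phase (the FIXME marks this as hard-coded for Gemma3). A small sketch of that selection logic, under the stated assumptions:

```python
# Hedged sketch (not the library API): argument selection as it appears after this change.
from typing import Optional
import torch

def select_attention_args(
    phase: str,
    attn_mask: Optional[torch.Tensor],
    use_attention_mask: bool,
    use_position_ids: bool,
):
    use_masked_kernel = use_attention_mask and not use_position_ids  # legacy masked attention ops
    kernel_mask = attn_mask if use_position_ids else None            # mask handed to the causal ops
    is_bidirectional = phase == "image_prefill"                      # Gemma3-style image prefill
    return use_masked_kernel, kernel_mask, is_bidirectional

print(select_attention_args("image_prefill", torch.ones(1, 1, 4, 4), True, True))
```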
@@ -927,10 +949,13 @@ class RotaryEmbedding(nn.Module):
 
 
 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask):
+    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask, use_position_ids):
         self.kvcache_partition_size = kvcache_partition_len
         super().__init__(
-            self_attn=self_attn,
+            self_attn=self_attn,
+            use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
+            kvcache_block_size=kvcache_block_size,
         )
 
     def get_attention(self):
@@ -940,6 +965,7 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
             self.num_key_value_heads,
             self.kvcache_partition_size,
             self.use_attention_mask,
+            self.use_position_ids,
         )
 
     def forward(
@@ -991,12 +1017,14 @@ class FlashAttentionOp(AttentionOp):
         num_key_value_heads: int,
         kvcache_partition_len: int,
         use_attention_mask: bool,
+        use_position_ids: bool,
     ):
         super().__init__(
             num_heads=num_heads,
             head_dim=head_dim,
             num_key_value_heads=num_key_value_heads,
             use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
         )
         self.kvcache_partition_size = kvcache_partition_len
 
@@ -1016,7 +1044,7 @@ class FlashAttentionOp(AttentionOp):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-        if self.use_attention_mask:
+        if self.use_attention_mask and not self.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -1033,7 +1061,7 @@ class FlashAttentionOp(AttentionOp):
         )
 
         if self.phase == "decode":
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
                     q=query_state,
                     k=key_state,
@@ -1059,10 +1087,10 @@ class FlashAttentionOp(AttentionOp):
                     block_table=block_tables,
                     block_size=kvcache_block_size,
                     partition=self.kvcache_partition_size,
-                    mask=None,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
         else:
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
                     q=query_state,
                     k=key_state,
@@ -1088,8 +1116,8 @@ class FlashAttentionOp(AttentionOp):
                     block_table=block_tables,
                     block_size=kvcache_block_size,
                     partition=self.kvcache_partition_size,
-                    is_bidirectional=False,
-                    mask=None,
+                    is_bidirectional=True if self.phase == "image_prefill" else False,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
 
         # reshape for removing repeat_kv
@@ -1098,3 +1126,70 @@ class FlashAttentionOp(AttentionOp):
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
         return attn_output
+
+
+class SlidingWindowAttentionOp(AttentionOp):
+    def forward(
+        self,
+        query_state: torch.Tensor,
+        key_state: torch.Tensor,
+        value_state: torch.Tensor,
+        attn_mask: torch.Tensor,
+        past_key_state: torch.Tensor,
+        past_value_state: torch.Tensor,
+        seq_position: Tuple[torch.Tensor],
+        scale: torch.Tensor,
+        block_tables: torch.Tensor,
+        block_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)
+        value_state = value_state.unsqueeze(2)
+
+        if self.phase == "decode":
+            batch_size = key_state.shape[0]
+        else:
+            batch_size = 1
+
+        query_state = query_state.view(
+            batch_size,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
+        )
+
+        if self.phase == "decode":
+            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_decode(
+                q=query_state,
+                k=key_state,
+                v=value_state,
+                kcache=past_key_state.unsqueeze(2),
+                vcache=past_value_state.unsqueeze(2),
+                cache_seq_len=seq_position[0],
+                cache_offset=seq_position[1],
+                scale=scale,
+                block_table=block_tables,
+                block_size=block_size,
+            )
+        else:
+            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_prefill(
+                q=query_state,
+                k=key_state,
+                v=value_state,
+                kcache=past_key_state.unsqueeze(2),
+                vcache=past_value_state.unsqueeze(2),
+                cache_seq_len=seq_position[0],
+                cache_offset=seq_position[1],
+                scale=scale,
+                block_table=block_tables,
+                block_size=block_size,
+                is_bidirectional=True if self.phase == "image_prefill" else False,
+            )
+
+        # reshape for removing repeat_kv
+        attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
+
+        return attn_output
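Note: `SlidingWindowAttentionOp` depends on the `rbln_custom_ops.paged_sliding_window_attn_*` kernels, so it cannot run without the RBLN runtime, but the query reshape it performs is plain PyTorch: heads are regrouped by KV head so the kernel can avoid an explicit repeat_kv of the cache. A standalone illustration of that reshape (shapes are illustrative):

```python
# Standalone illustration of the grouped-query reshape used above.
import torch

batch, num_heads, num_kv_heads, seq, head_dim = 1, 8, 2, 16, 64
q = torch.randn(batch, num_heads, seq, head_dim)

# (batch, num_heads, seq, head_dim) -> (batch, num_kv_heads, heads_per_kv, seq, head_dim)
q_grouped = q.view(batch, num_kv_heads, num_heads // num_kv_heads, seq, head_dim)
assert q_grouped.shape == (1, 2, 4, 16, 64)
```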
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
CHANGED
@@ -167,6 +167,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -193,6 +195,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 attention_mask=attention_mask,
                 position_embed=position_embed,
                 position_ids=position_ids,
+                local_block_tables=local_block_tables,
             )
         else:
             return self.prefill_forward(
@@ -202,6 +205,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 batch_idx,
                 block_tables,
                 position_embed=position_embed,
+                token_type_ids=token_type_ids,
+                local_block_tables=local_block_tables,
             )
 
     def decode_forward(
@@ -213,6 +218,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -262,6 +268,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
     ):
         """
         Prepare inputs for prefill phase.
@@ -345,6 +353,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
         position_embed: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -360,7 +370,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             position_embed,
             padded_cache_lengths,
             query_length,
-        ) = self._prepare_prefill_inputs(
+        ) = self._prepare_prefill_inputs(
+            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+        )
 
         # Process input in chunks of size `prefill_chunk_size`
         for step in range(0, query_length, self.prefill_chunk_size):
@@ -373,7 +385,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             if position_embed is not None:
                 position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]
 
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 # Update attention mask to ensure proper causal behavior
                 if step >= self.prefill_chunk_size:
                     chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
@@ -387,10 +399,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
+                position_embed_chunk if position_embed is not None else None,
                 query_position,
                 chunked_attention_mask if self.use_attention_mask else None,
-                position_ids_chunk if
-                position_embed_chunk if position_embed is not None else None,
+                position_ids_chunk if self.use_position_ids else None,
                 out=out_buffers,
             )
 
@@ -440,12 +452,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if self.rbln_config.use_inputs_embeds:
             main_input_name = "inputs_embeds"
             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-
-            self.embed_tokens = torch.nn.Embedding(
-                self.config.vocab_size,
-                self.config.hidden_size,
-                self.config.pad_token_id,
-            )
+            self.embed_tokens = self._create_embedding_layer()
             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
         else:
             self.embed_tokens = None
@@ -478,6 +485,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             attn_impl=self.rbln_config.attn_impl,
             use_position_ids=self.rbln_config.use_position_ids,
         )
+
         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNRuntimeModel(
@@ -515,6 +523,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
 
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                self.config.pad_token_id,
+            )
+        return embed_tokens
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
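Note: `_create_embedding_layer` builds the embedding table under `no_init_weights` so the random-init pass is skipped before the saved weights from `torch_artifacts.pth` are loaded. A minimal sketch of the same pattern, assuming the context manager is transformers' `no_init_weights` (the import is not shown in this diff):

```python
# Minimal sketch, assuming transformers.modeling_utils.no_init_weights; sizes are illustrative.
import torch
from transformers.modeling_utils import no_init_weights

def create_embedding_layer(vocab_size: int, hidden_size: int, pad_token_id: int) -> torch.nn.Embedding:
    with no_init_weights():  # skip weight init; real weights are loaded right afterwards
        return torch.nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)

embed = create_embedding_layer(vocab_size=1000, hidden_size=64, pad_token_id=0)
embed.load_state_dict({"weight": torch.zeros(1000, 64)})
```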
@@ -1101,6 +1118,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
@@ -1123,6 +1141,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
+                    token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)
optimum/rbln/transformers/models/exaone/exaone_architecture.py
CHANGED
@@ -41,7 +41,10 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
         for layer in causal_lm.transformer.h:
             if self.attn_impl == "eager":
                 new_self_attn = ExaoneAttention(
-                    layer.attn.attention,
+                    layer.attn.attention,
+                    self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = ExaoneFlashAttention(
@@ -49,6 +52,7 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
                     kvcache_partition_len=self.kvcache_partition_len,
                     use_attention_mask=self.use_attention_mask,
                     kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
optimum/rbln/transformers/models/gemma/gemma_architecture.py
CHANGED
@@ -34,7 +34,10 @@ class GemmaWrapper(DecoderOnlyWrapper):
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
                 new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn,
+                    layer.self_attn,
+                    self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
@@ -42,6 +45,7 @@ class GemmaWrapper(DecoderOnlyWrapper):
                     kvcache_partition_len=self.kvcache_partition_len,
                     use_attention_mask=self.use_attention_mask,
                     kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
optimum/rbln/transformers/models/gemma3/__init__.py
ADDED
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig, RBLNGemma3ForConditionalGenerationConfig
+from .modeling_gemma3 import RBLNGemma3ForCausalLM, RBLNGemma3ForConditionalGeneration
|