optimum-rbln 0.9.3__py3-none-any.whl → 0.9.4a2__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +12 -4
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -1
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/modeling_base.py +12 -7
- optimum/rbln/transformers/modeling_attention_utils.py +4 -4
- optimum/rbln/transformers/modeling_outputs.py +1 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +1 -1
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +4 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +201 -62
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +106 -36
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +43 -26
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +0 -1
- optimum/rbln/transformers/models/llava/modeling_llava.py +1 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -6
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +4 -4
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +2 -2
- optimum/rbln/transformers/models/swin/modeling_swin.py +3 -3
- optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +9 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -2
- optimum/rbln/utils/import_utils.py +7 -1
- optimum/rbln/utils/submodule.py +3 -1
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +1 -1
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +52 -52
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -21,7 +21,6 @@ from transformers import PretrainedConfig, PreTrainedModel
 
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...utils.rbln_quantization import RBLNQuantizationConfig
 from .configuration_lora import RBLNLoRAConfig
 from .lora_architecture import LoRALinear
 
@@ -77,7 +76,7 @@ class DecoderOnlyWrapper(nn.Module):
         )
 
         self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
-        self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or
+        self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or self.config.n_layer
         self._phase = "prefill"
 
     def get_rotary_emb(self, max_seq_len):
@@ -203,7 +202,7 @@ class DecoderOnlyWrapper(nn.Module):
             rotary_emb,
         ) = self.prepare_forward_args(*args)
 
-
+        logits, all_hidden_states = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -215,9 +214,13 @@ class DecoderOnlyWrapper(nn.Module):
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
             lora_int_id=lora_int_id,
+            output_hidden_states=self.rbln_config.output_hidden_states,
         )
 
-
+        if self.rbln_config.output_hidden_states:
+            return logits, all_hidden_states
+        else:
+            return logits
 
 
 class DecoderOnlyForCausalLM(nn.Module):
@@ -272,9 +275,10 @@ class DecoderOnlyForCausalLM(nn.Module):
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # outputs
-        hidden_states = self.model(
+        hidden_states, all_hidden_states = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -286,6 +290,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
             lora_int_id=lora_int_id,
+            output_hidden_states=output_hidden_states,
         )
 
         if "prefill" in self.phase:
@@ -299,7 +304,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             logits = torch.tanh(logits)
             logits = logits * self.config.final_logit_softcapping
 
-        return logits
+        return logits, all_hidden_states
 
 
 class DecoderOnlyModel(nn.Module):
@@ -398,6 +403,7 @@ class DecoderOnlyModel(nn.Module):
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -460,7 +466,11 @@ class DecoderOnlyModel(nn.Module):
         if len(self.sliding_window_layers) > 0:
             sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
 
+        all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
             is_sliding = True if layer_idx in self.sliding_window_layers else False
             hidden_states = layer(
                 hidden_states=hidden_states,
@@ -474,7 +484,10 @@ class DecoderOnlyModel(nn.Module):
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
-
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        return hidden_states, all_hidden_states
 
 
 class DecoderOnlyLayer(nn.Module):
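The hunks above thread an `output_hidden_states` flag through the decoder stack and collect the input to every layer plus the final post-norm state into a tuple. Below is a minimal standalone sketch of that collection pattern, using a toy module rather than the optimum-rbln classes; `TinyModel` and its sizes are illustrative only.

import torch
from torch import nn


class TinyModel(nn.Module):
    # Toy stand-in for a decoder stack: a few layers plus a final norm.
    def __init__(self, hidden_size: int = 8, num_layers: int = 3):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor, output_hidden_states: bool = False):
        # Same accumulation shape as the diff: record the input to each layer,
        # then append the final post-norm state once more.
        all_hidden_states = () if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            hidden_states = layer(hidden_states)
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        return hidden_states, all_hidden_states


last, collected = TinyModel()(torch.randn(1, 4, 8), output_hidden_states=True)
assert len(collected) == 3 + 1  # one entry per layer input, plus the final normalized state

With num_layers layers the tuple ends up with num_layers + 1 entries, matching the usual Hugging Face hidden_states convention.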
@@ -616,13 +629,12 @@ class DecoderOnlyAttention(nn.Module):
         self._original_mod = self_attn
         self.rbln_config = rbln_config
         self.layer_idx = self_attn.layer_idx
-        self.num_heads =
-            self._original_mod
+        self.num_heads = (
+            getattr(self._original_mod, "num_heads", None) or self._original_mod.config.num_attention_heads
         )
         self.head_dim = self._original_mod.head_dim
         self._phase = "prefill"
         self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
-        self.quantization = rbln_config.quantization
 
         if hasattr(self._original_mod, "num_key_value_heads"):
             self.num_key_value_heads = self._original_mod.num_key_value_heads
@@ -631,8 +643,6 @@ class DecoderOnlyAttention(nn.Module):
         else:
             self.num_key_value_heads = self.num_heads
 
-        self.use_attention_mask = rbln_config.use_attention_mask if not is_sliding else True
-        self.use_position_ids = rbln_config.use_position_ids
         self.is_sliding = is_sliding
         self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
         self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
@@ -680,8 +690,7 @@ class DecoderOnlyAttention(nn.Module):
                 self.num_heads,
                 self.head_dim,
                 self.num_key_value_heads,
-                self.
-                self.use_position_ids,
+                rbln_config=self.rbln_config,
             )
         elif self.attn_impl == "flash_attn":
             return FlashAttentionOp(
@@ -689,18 +698,16 @@ class DecoderOnlyAttention(nn.Module):
                 self.head_dim,
                 self.num_key_value_heads,
                 self.kvcache_partition_len,
-                self.
-
-                self.quantization,
+                rbln_config=self.rbln_config,
+                is_sliding=False,
             )
         elif self.attn_impl == "eager":
             return AttentionOp(
                 self.num_heads,
                 self.head_dim,
                 self.num_key_value_heads,
-                self.
-
-                self.quantization,
+                rbln_config=self.rbln_config,
+                is_sliding=False,
             )
         else:
             raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
@@ -830,23 +837,27 @@ class AttentionOp(nn.Module):
         num_heads: int,
         head_dim: int,
         num_key_value_heads: int,
-
-
-        quantization: Optional[RBLNQuantizationConfig] = None,
+        rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+        is_sliding: bool = False,
     ):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
-        self.
-        self.
-        self.
+        self.rbln_config = rbln_config
+        self.use_attention_mask = True if is_sliding else rbln_config.use_attention_mask
+        self.use_position_ids = rbln_config.use_position_ids
+        self.quantization = rbln_config.quantization
 
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-
-
+
+        if self.use_attention_mask:
+            if self.rbln_config.use_position_ids:
+                attn_op_name = "paged_causal_attn_"
+            else:
+                attn_op_name = "paged_attn_"
         else:
             attn_op_name = "paged_causal_attn_"
 
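The constructor hunks above replace the individually passed flags (use_attention_mask, use_position_ids, quantization) with a single rbln_config object plus an is_sliding switch. A minimal sketch of that calling convention follows, with a hypothetical AttnConfig dataclass standing in for RBLNDecoderOnlyModelConfig.

from dataclasses import dataclass
from typing import Optional


@dataclass
class AttnConfig:
    # Hypothetical stand-in for RBLNDecoderOnlyModelConfig; only the fields used here.
    use_attention_mask: bool = True
    use_position_ids: bool = False
    quantization: Optional[object] = None


class AttentionOpSketch:
    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        num_key_value_heads: int,
        rbln_config: Optional[AttnConfig] = None,
        is_sliding: bool = False,
    ):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.rbln_config = rbln_config or AttnConfig()
        # Sliding-window attention always keeps an attention mask; otherwise the config decides.
        self.use_attention_mask = True if is_sliding else self.rbln_config.use_attention_mask
        self.use_position_ids = self.rbln_config.use_position_ids
        self.quantization = self.rbln_config.quantization


op = AttentionOpSketch(32, 128, 8, rbln_config=AttnConfig(use_attention_mask=False), is_sliding=True)
assert op.use_attention_mask is True  # forced on for sliding-window layers

Passing the config object keeps the attention-op constructors stable when new per-model options are added, which is the apparent motivation for the change.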
@@ -895,7 +906,7 @@ class AttentionOp(nn.Module):
         key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
 
-        if self.use_attention_mask and not self.use_position_ids:
+        if self.use_attention_mask and not self.rbln_config.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -927,8 +938,14 @@ class AttentionOp(nn.Module):
             op_args["mask"] = attn_mask
 
         if self.phase == "prefill" or self.phase == "image_prefill":
-
-
+            use_image_prefill = getattr(self.rbln_config, "use_image_prefill", False)
+            if use_image_prefill:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"
+            else:
+                if not self.use_attention_mask:
+                    op_args["is_bidirectional"] = False
+                elif self.use_attention_mask and self.rbln_config.use_position_ids:
+                    op_args["is_bidirectional"] = True
 
         if self.quantization and self.quantization.kv_caches == "fp8":
             if past_key_state.dtype != torch.float8_e4m3fn:
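The prefill branch above decides whether the custom attention op runs bidirectionally. Distilled into a standalone helper for illustration (not part of optimum-rbln; None stands for "leave op_args unchanged"):

from typing import Optional


def select_is_bidirectional(
    phase: str,
    use_image_prefill: bool,
    use_attention_mask: bool,
    use_position_ids: bool,
) -> Optional[bool]:
    # Mirrors the AttentionOp / FlashAttentionOp prefill branch shown above.
    if use_image_prefill:
        # Only the image-prefill phase runs bidirectionally.
        return phase == "image_prefill"
    if not use_attention_mask:
        return False
    if use_attention_mask and use_position_ids:
        return True
    return None  # masked attention without position ids: the flag is not set


assert select_is_bidirectional("image_prefill", True, False, False) is True
assert select_is_bidirectional("prefill", False, False, False) is False
assert select_is_bidirectional("prefill", False, True, True) is True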
@@ -956,24 +973,26 @@ class FlashAttentionOp(AttentionOp):
         head_dim: int,
         num_key_value_heads: int,
         kvcache_partition_len: int,
-
-
-        quantization: Optional[RBLNQuantizationConfig] = None,
+        rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+        is_sliding: bool = False,
     ):
         super().__init__(
             num_heads=num_heads,
             head_dim=head_dim,
             num_key_value_heads=num_key_value_heads,
-
-
-            quantization=quantization,
+            rbln_config=rbln_config,
+            is_sliding=is_sliding,
         )
         self.kvcache_partition_size = kvcache_partition_len
 
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-
-
+
+        if self.use_attention_mask:
+            if self.rbln_config.use_position_ids:
+                attn_op_name = "paged_flash_causal_attn_"
+            else:
+                attn_op_name = "paged_flash_attn_"
         else:
             attn_op_name = "paged_flash_causal_attn_"
 
@@ -1002,7 +1021,8 @@ class FlashAttentionOp(AttentionOp):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-
+
+        if self.use_attention_mask and not self.rbln_config.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -1035,8 +1055,14 @@ class FlashAttentionOp(AttentionOp):
             op_args["mask"] = attn_mask
 
         if self.phase == "prefill" or self.phase == "image_prefill":
-
-
+            use_image_prefill = getattr(self.rbln_config, "use_image_prefill", False)
+            if use_image_prefill:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"
+            else:
+                if not self.use_attention_mask:
+                    op_args["is_bidirectional"] = False
+                elif self.use_attention_mask and self.rbln_config.use_position_ids:
+                    op_args["is_bidirectional"] = True
 
         if self.quantization and self.quantization.kv_caches == "fp8":
             if past_key_state.dtype != torch.float8_e4m3fn:
@@ -1058,6 +1084,22 @@ class FlashAttentionOp(AttentionOp):
 
 
 class SlidingWindowAttentionOp(AttentionOp):
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        num_key_value_heads: int,
+        rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+    ):
+        super().__init__(
+            num_heads=num_heads,
+            head_dim=head_dim,
+            num_key_value_heads=num_key_value_heads,
+            rbln_config=rbln_config,
+            is_sliding=True,
+        )
+        self.quantization = None # Sliding window attention does not support quantization
+
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
         if not self.use_attention_mask:
@@ -1115,7 +1157,14 @@ class SlidingWindowAttentionOp(AttentionOp):
         }
 
         if self.phase == "prefill" or self.phase == "image_prefill":
-
+            use_image_prefill = getattr(self.rbln_config, "use_image_prefill", False)
+            if use_image_prefill:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"
+            else:
+                if self.use_attention_mask and self.rbln_config.use_position_ids:
+                    op_args["is_bidirectional"] = True
+                else:
+                    op_args["is_bidirectional"] = False
 
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)