optimum-rbln 0.9.3__py3-none-any.whl → 0.9.4a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. optimum/rbln/__version__.py +2 -2
  2. optimum/rbln/configuration_utils.py +12 -4
  3. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  4. optimum/rbln/diffusers/models/controlnet.py +1 -1
  5. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -1
  6. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  11. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  12. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  13. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  14. optimum/rbln/modeling_base.py +12 -7
  15. optimum/rbln/transformers/modeling_attention_utils.py +4 -4
  16. optimum/rbln/transformers/modeling_outputs.py +1 -0
  17. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  18. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  19. optimum/rbln/transformers/models/colpali/modeling_colpali.py +1 -1
  20. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
  21. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +4 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  23. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +201 -62
  25. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +106 -36
  27. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
  28. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  29. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +43 -26
  30. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
  31. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +0 -1
  32. optimum/rbln/transformers/models/llava/modeling_llava.py +1 -1
  33. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -1
  34. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  35. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  36. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  37. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -6
  38. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +4 -4
  39. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
  40. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  41. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +2 -2
  42. optimum/rbln/transformers/models/swin/modeling_swin.py +3 -3
  43. optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
  44. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +9 -8
  45. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -2
  46. optimum/rbln/utils/import_utils.py +7 -1
  47. optimum/rbln/utils/submodule.py +3 -1
  48. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +1 -1
  49. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +52 -52
  50. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +0 -0
  51. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
  52. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
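The bulk of this change set threads a new output_hidden_states option through the decoder-only stack and, in parallel, refactors the attention ops to take the full RBLN config object instead of individual use_attention_mask / use_position_ids / quantization flags; both changes are visible in the hunks below. As a rough, hedged illustration of how such a config field is typically surfaced to users of optimum-rbln — the model class exists in the library, but the checkpoint id and the exact way the flag is passed here are assumptions for the sketch, not taken from this diff:

# Hedged usage sketch: enabling the new output_hidden_states field when
# compiling a decoder-only model. The checkpoint id and the rbln_config dict
# form used below are illustrative assumptions, not confirmed by this diff.
from optimum.rbln import RBLNLlamaForCausalLM

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",                    # assumed checkpoint id
    export=True,                                  # compile from the PyTorch weights
    rbln_config={"output_hidden_states": True},   # new config field in this release
)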
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -21,7 +21,6 @@ from transformers import PretrainedConfig, PreTrainedModel
 
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...utils.rbln_quantization import RBLNQuantizationConfig
 from .configuration_lora import RBLNLoRAConfig
 from .lora_architecture import LoRALinear
 
@@ -77,7 +76,7 @@ class DecoderOnlyWrapper(nn.Module):
         )
 
         self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
-        self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
+        self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or self.config.n_layer
         self._phase = "prefill"
 
     def get_rotary_emb(self, max_seq_len):
@@ -203,7 +202,7 @@ class DecoderOnlyWrapper(nn.Module):
             rotary_emb,
         ) = self.prepare_forward_args(*args)
 
-        logit = self.model(
+        logits, all_hidden_states = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -215,9 +214,13 @@ class DecoderOnlyWrapper(nn.Module):
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
             lora_int_id=lora_int_id,
+            output_hidden_states=self.rbln_config.output_hidden_states,
         )
 
-        return logit
+        if self.rbln_config.output_hidden_states:
+            return logits, all_hidden_states
+        else:
+            return logits
 
 
 class DecoderOnlyForCausalLM(nn.Module):
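Because the wrapper's return arity now depends on rbln_config.output_hidden_states, anything consuming the compiled graph has to unpack accordingly (the file list above shows matching changes in decoderonly_runtime_utils.py and modeling_decoderonly.py). A toy sketch of that contract, with a helper name of our own:

# Toy illustration of the wrapper's new return contract: a single logits tensor
# when output_hidden_states is off, a (logits, all_hidden_states) pair when it
# is on. The helper name is ours, not from the library.
import torch

def unpack_wrapper_output(output, output_hidden_states: bool):
    if output_hidden_states:
        logits, all_hidden_states = output
        return logits, all_hidden_states
    return output, None

logits, hidden = unpack_wrapper_output(torch.zeros(1, 1, 128), output_hidden_states=False)
assert hidden is None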
@@ -272,9 +275,10 @@ class DecoderOnlyForCausalLM(nn.Module):
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # outputs
-        hidden_states = self.model(
+        hidden_states, all_hidden_states = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -286,6 +290,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
             lora_int_id=lora_int_id,
+            output_hidden_states=output_hidden_states,
         )
 
         if "prefill" in self.phase:
@@ -299,7 +304,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             logits = torch.tanh(logits)
             logits = logits * self.config.final_logit_softcapping
 
-        return logits
+        return logits, all_hidden_states
 
 
 class DecoderOnlyModel(nn.Module):
@@ -398,6 +403,7 @@ class DecoderOnlyModel(nn.Module):
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -460,7 +466,11 @@ class DecoderOnlyModel(nn.Module):
         if len(self.sliding_window_layers) > 0:
             sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
 
+        all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
             is_sliding = True if layer_idx in self.sliding_window_layers else False
             hidden_states = layer(
                 hidden_states=hidden_states,
@@ -474,7 +484,10 @@ class DecoderOnlyModel(nn.Module):
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
-        return hidden_states
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        return hidden_states, all_hidden_states
 
 
 class DecoderOnlyLayer(nn.Module):
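The per-layer collection added above follows the usual Hugging Face convention: the tuple receives the input of every decoder layer and, after the final layernorm, the last hidden state as well. A minimal self-contained sketch of that pattern (toy modules, not the real RBLN classes):

# Minimal sketch of the hidden-state collection pattern introduced in
# DecoderOnlyModel.forward. ToyDecoder is a stand-in for illustration only.
import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    def __init__(self, hidden: int = 8, n_layers: int = 3):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(hidden)

    def forward(self, hidden_states: torch.Tensor, output_hidden_states: bool = False):
        all_hidden_states = () if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)  # state fed into this layer
            hidden_states = layer(hidden_states)
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)      # final post-norm state
        return hidden_states, all_hidden_states

last, states = ToyDecoder()(torch.randn(1, 4, 8), output_hidden_states=True)
assert len(states) == 3 + 1  # one entry per layer input, plus the final state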
@@ -616,13 +629,12 @@ class DecoderOnlyAttention(nn.Module):
         self._original_mod = self_attn
         self.rbln_config = rbln_config
         self.layer_idx = self_attn.layer_idx
-        self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
-            self._original_mod.config, "num_attention_heads"
+        self.num_heads = (
+            getattr(self._original_mod, "num_heads", None) or self._original_mod.config.num_attention_heads
         )
         self.head_dim = self._original_mod.head_dim
         self._phase = "prefill"
         self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
-        self.quantization = rbln_config.quantization
 
         if hasattr(self._original_mod, "num_key_value_heads"):
             self.num_key_value_heads = self._original_mod.num_key_value_heads
@@ -631,8 +643,6 @@ class DecoderOnlyAttention(nn.Module):
         else:
             self.num_key_value_heads = self.num_heads
 
-        self.use_attention_mask = rbln_config.use_attention_mask if not is_sliding else True
-        self.use_position_ids = rbln_config.use_position_ids
         self.is_sliding = is_sliding
         self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
         self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
@@ -680,8 +690,7 @@ class DecoderOnlyAttention(nn.Module):
                 self.num_heads,
                 self.head_dim,
                 self.num_key_value_heads,
-                self.use_attention_mask,
-                self.use_position_ids,
+                rbln_config=self.rbln_config,
             )
         elif self.attn_impl == "flash_attn":
             return FlashAttentionOp(
@@ -689,18 +698,16 @@ class DecoderOnlyAttention(nn.Module):
                 self.head_dim,
                 self.num_key_value_heads,
                 self.kvcache_partition_len,
-                self.use_attention_mask,
-                self.use_position_ids,
-                self.quantization,
+                rbln_config=self.rbln_config,
+                is_sliding=False,
             )
         elif self.attn_impl == "eager":
             return AttentionOp(
                 self.num_heads,
                 self.head_dim,
                 self.num_key_value_heads,
-                self.use_attention_mask,
-                self.use_position_ids,
-                self.quantization,
+                rbln_config=self.rbln_config,
+                is_sliding=False,
            )
         else:
             raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
@@ -830,23 +837,27 @@ class AttentionOp(nn.Module):
         num_heads: int,
         head_dim: int,
         num_key_value_heads: int,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        quantization: Optional[RBLNQuantizationConfig] = None,
+        rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+        is_sliding: bool = False,
     ):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
-        self.use_attention_mask = use_attention_mask
-        self.use_position_ids = use_position_ids
-        self.quantization = quantization
+        self.rbln_config = rbln_config
+        self.use_attention_mask = True if is_sliding else rbln_config.use_attention_mask
+        self.use_position_ids = rbln_config.use_position_ids
+        self.quantization = rbln_config.quantization
 
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-        if self.use_attention_mask and not self.use_position_ids:
-            attn_op_name = "paged_attn_"
+
+        if self.use_attention_mask:
+            if self.rbln_config.use_position_ids:
+                attn_op_name = "paged_causal_attn_"
+            else:
+                attn_op_name = "paged_attn_"
         else:
             attn_op_name = "paged_causal_attn_"
 
@@ -895,7 +906,7 @@ class AttentionOp(nn.Module):
         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
 
-        if self.use_attention_mask and not self.use_position_ids:
+        if self.use_attention_mask and not self.rbln_config.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -927,8 +938,14 @@ class AttentionOp(nn.Module):
             op_args["mask"] = attn_mask
 
         if self.phase == "prefill" or self.phase == "image_prefill":
-            if not self.use_attention_mask or self.use_position_ids:
-                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+            use_image_prefill = getattr(self.rbln_config, "use_image_prefill", False)
+            if use_image_prefill:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"
+            else:
+                if not self.use_attention_mask:
+                    op_args["is_bidirectional"] = False
+                elif self.use_attention_mask and self.rbln_config.use_position_ids:
+                    op_args["is_bidirectional"] = True
 
         if self.quantization and self.quantization.kv_caches == "fp8":
             if past_key_state.dtype != torch.float8_e4m3fn:
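Two decisions in the AttentionOp changes above are easy to misread in diff form, so here is a standalone restatement: the kernel-name choice in get_attn_op_name is behaviourally unchanged (it now simply reads use_position_ids from the shared rbln_config), while the prefill branch replaces the Gemma3-specific hard-coding with an explicit use_image_prefill config flag. Sketch only; the helper names are ours, and None means the is_bidirectional argument is simply not set.

# Restatement of the two decision points above; not library code.
from typing import Optional

def attn_op_prefix(use_attention_mask: bool, use_position_ids: bool) -> str:
    # Kernel-name prefix from AttentionOp.get_attn_op_name; the real method
    # appends "decode" or "prefill" depending on self.phase.
    if use_attention_mask and not use_position_ids:
        return "paged_attn_"         # explicit mask without position ids
    return "paged_causal_attn_"      # every other combination uses the causal op

def is_bidirectional(phase: str, use_image_prefill: bool,
                     use_attention_mask: bool, use_position_ids: bool) -> Optional[bool]:
    # Prefill-branch selection from AttentionOp.forward / FlashAttentionOp.forward.
    if use_image_prefill:
        return phase == "image_prefill"   # bidirectional only for image-prefill chunks
    if not use_attention_mask:
        return False
    if use_position_ids:
        return True
    return None                            # op_args["is_bidirectional"] left unset

assert attn_op_prefix(True, False) == "paged_attn_"
assert attn_op_prefix(False, True) == "paged_causal_attn_"
assert is_bidirectional("image_prefill", True, True, True) is True
assert is_bidirectional("prefill", False, False, True) is False
assert is_bidirectional("prefill", False, True, False) is None

The sliding-window variant at the end of the file follows the same shape but always sets the flag, defaulting to False instead of leaving it unset.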
@@ -956,24 +973,26 @@ class FlashAttentionOp(AttentionOp):
         head_dim: int,
         num_key_value_heads: int,
         kvcache_partition_len: int,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        quantization: Optional[RBLNQuantizationConfig] = None,
+        rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+        is_sliding: bool = False,
     ):
         super().__init__(
             num_heads=num_heads,
             head_dim=head_dim,
             num_key_value_heads=num_key_value_heads,
-            use_attention_mask=use_attention_mask,
-            use_position_ids=use_position_ids,
-            quantization=quantization,
+            rbln_config=rbln_config,
+            is_sliding=is_sliding,
         )
         self.kvcache_partition_size = kvcache_partition_len
 
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-        if self.use_attention_mask and not self.use_position_ids:
-            attn_op_name = "paged_flash_attn_"
+
+        if self.use_attention_mask:
+            if self.rbln_config.use_position_ids:
+                attn_op_name = "paged_flash_causal_attn_"
+            else:
+                attn_op_name = "paged_flash_attn_"
         else:
             attn_op_name = "paged_flash_causal_attn_"
 
@@ -1002,7 +1021,8 @@ class FlashAttentionOp(AttentionOp):
         # reshape for removing repeat_kv (batch=1, num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-        if self.use_attention_mask and not self.use_position_ids:
+
+        if self.use_attention_mask and not self.rbln_config.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -1035,8 +1055,14 @@ class FlashAttentionOp(AttentionOp):
             op_args["mask"] = attn_mask
 
         if self.phase == "prefill" or self.phase == "image_prefill":
-            if not self.use_attention_mask or self.use_position_ids:
-                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+            use_image_prefill = getattr(self.rbln_config, "use_image_prefill", False)
+            if use_image_prefill:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"
+            else:
+                if not self.use_attention_mask:
+                    op_args["is_bidirectional"] = False
+                elif self.use_attention_mask and self.rbln_config.use_position_ids:
+                    op_args["is_bidirectional"] = True
 
         if self.quantization and self.quantization.kv_caches == "fp8":
             if past_key_state.dtype != torch.float8_e4m3fn:
@@ -1058,6 +1084,22 @@ class FlashAttentionOp(AttentionOp):
 
 
 class SlidingWindowAttentionOp(AttentionOp):
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        num_key_value_heads: int,
+        rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+    ):
+        super().__init__(
+            num_heads=num_heads,
+            head_dim=head_dim,
+            num_key_value_heads=num_key_value_heads,
+            rbln_config=rbln_config,
+            is_sliding=True,
+        )
+        self.quantization = None  # Sliding window attention does not support quantization
+
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
         if not self.use_attention_mask:
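The new __init__ leans on the is_sliding flag added to AttentionOp: the base class now receives the whole rbln_config and forces use_attention_mask on for sliding-window ops, and the subclass additionally drops KV-cache quantization regardless of the global config. A toy sketch of that default resolution (ToyConfig and the helper are ours, standing in for RBLNDecoderOnlyModelConfig):

# Sketch of the defaults the new constructor enforces: a sliding-window
# attention op always uses an attention mask and never quantizes its KV cache.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyConfig:
    use_attention_mask: bool = False
    use_position_ids: bool = True
    quantization: Optional[str] = None

def resolve_attention_flags(cfg: ToyConfig, is_sliding: bool):
    use_mask = True if is_sliding else cfg.use_attention_mask
    quantization = None if is_sliding else cfg.quantization
    return use_mask, quantization

cfg = ToyConfig(use_attention_mask=False, quantization="fp8-kv")  # illustrative values
assert resolve_attention_flags(cfg, is_sliding=True) == (True, None)
assert resolve_attention_flags(cfg, is_sliding=False) == (False, "fp8-kv")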
@@ -1115,7 +1157,14 @@ class SlidingWindowAttentionOp(AttentionOp):
         }
 
         if self.phase == "prefill" or self.phase == "image_prefill":
-            op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+            use_image_prefill = getattr(self.rbln_config, "use_image_prefill", False)
+            if use_image_prefill:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"
+            else:
+                if self.use_attention_mask and self.rbln_config.use_position_ids:
+                    op_args["is_bidirectional"] = True
+                else:
+                    op_args["is_bidirectional"] = False
 
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)