optimum-rbln 0.9.3__py3-none-any.whl → 0.9.4a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. optimum/rbln/__version__.py +2 -2
  2. optimum/rbln/configuration_utils.py +12 -4
  3. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  4. optimum/rbln/diffusers/models/controlnet.py +1 -1
  5. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -1
  6. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  11. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  12. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  13. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  14. optimum/rbln/modeling_base.py +12 -7
  15. optimum/rbln/transformers/modeling_attention_utils.py +4 -4
  16. optimum/rbln/transformers/modeling_outputs.py +1 -0
  17. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  18. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  19. optimum/rbln/transformers/models/colpali/modeling_colpali.py +1 -1
  20. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
  21. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +4 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  23. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +201 -62
  25. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +106 -36
  27. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
  28. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  29. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +43 -26
  30. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
  31. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +0 -1
  32. optimum/rbln/transformers/models/llava/modeling_llava.py +1 -1
  33. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -1
  34. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  35. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  36. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  37. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -6
  38. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +4 -4
  39. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
  40. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  41. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +2 -2
  42. optimum/rbln/transformers/models/swin/modeling_swin.py +3 -3
  43. optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
  44. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +9 -8
  45. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -2
  46. optimum/rbln/utils/import_utils.py +7 -1
  47. optimum/rbln/utils/submodule.py +3 -1
  48. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +1 -1
  49. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +52 -52
  50. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +0 -0
  51. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
  52. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
@@ -88,8 +88,12 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  def setup_runtime(self):
  # Initialize resources to be used across Runtime instances (prefill and decode phases)
  page_table_manager = RBLNPageTableManager(self.rbln_config)
- dec_attn_mask = torch.zeros(self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype)
- out_buffers = [torch.empty(self.prefill_output_size, dtype=self.dtype)]
+ if self.rbln_config.use_position_ids:
+ dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=self.dtype)
+ else:
+ dec_attn_mask = torch.zeros(
+ self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype
+ )

  common_kwargs = {
  "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
@@ -97,12 +101,13 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  "dec_attn_mask": dec_attn_mask,
  "page_table_manager": page_table_manager,
  "rbln_config": self.rbln_config,
+ "config": self.config,
  }
  self.prefill_decoder = RBLNRuntimeModel(
  runtime=self.model[0],
  phase="prefill",
  batch_size=self.rbln_config.batch_size,
- out_buffers=out_buffers,
+ logits_last_dim=self.logits_last_dim,
  **common_kwargs,
  )
  if self.can_generate():
@@ -119,12 +124,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  self.decoder = self.decoders[self.rbln_config.batch_size]

  @property
- def prefill_output_size(self):
- return (
- 1,
- self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
- self.config.hidden_size,
- )
+ def logits_last_dim(self):
+ return self.config.hidden_size

  @classmethod
  def get_quantized_model(
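
These first hunks replace the preallocated prefill output buffers (sized via `prefill_output_size`) with a `logits_last_dim` property and pick the decoder attention-mask shape based on `use_position_ids`. A minimal sketch of the two mask layouts, using a hypothetical stand-in for `rbln_config` (only `batch_size`, `max_seq_len`, and `use_position_ids` are assumed here):

import torch

class _Cfg:
    # Hypothetical stand-in for rbln_config; not part of optimum-rbln.
    batch_size = 2
    max_seq_len = 8
    use_position_ids = True

cfg = _Cfg()
if cfg.use_position_ids:
    # 2-D mask: one slot per cached position for each batch element.
    dec_attn_mask = torch.zeros(cfg.batch_size, cfg.max_seq_len)
else:
    # 4-D mask: broadcastable over attention heads and query positions.
    dec_attn_mask = torch.zeros(cfg.batch_size, 1, 1, cfg.max_seq_len)

print(dec_attn_mask.shape)  # torch.Size([2, 8]) here; torch.Size([2, 1, 1, 8]) otherwise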
@@ -340,10 +341,10 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
  model_config: PretrainedConfig,
  ):
- num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+ num_attention_heads = getattr(model_config, "n_head", None) or model_config.num_attention_heads
  num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
- num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
- hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+ num_hidden_layers = getattr(model_config, "n_layer", None) or model_config.num_hidden_layers
+ hidden_size = getattr(model_config, "n_embd", None) or model_config.hidden_size
  head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
  is_prefill = query_length > 1

@@ -439,10 +440,22 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  # Returns:
  # RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.

- raise NotImplementedError(
- "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
- "See method docstring for required configuration details."
+ rbln_config.sliding_window = model_config.sliding_window
+ sliding_window_layers = []
+
+ for i in range(model_config.num_hidden_layers):
+ if hasattr(model_config, "layer_types"):
+ if model_config.layer_types[i] == "sliding_attention":
+ sliding_window_layers.append(i)
+ else:
+ sliding_window_layers.append(i)
+
+ rbln_config.sliding_window_layers = sliding_window_layers
+
+ rbln_config.cache_impl = (
+ "sliding_window" if len(sliding_window_layers) == model_config.num_hidden_layers else "hybrid"
  )
+ return rbln_config

  @classmethod
  def _update_attention_config(
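
`_update_sliding_window_config` gains a default implementation instead of raising `NotImplementedError`: layers whose `layer_types` entry is "sliding_attention" become sliding-window layers (every layer, if the config has no `layer_types`), and `cache_impl` is "sliding_window" only when all layers slide, "hybrid" otherwise. A standalone sketch of that selection rule on a made-up `layer_types` list, not a real model config:

# Illustrative only: mirrors the new default selection on a toy config.
num_hidden_layers = 6
layer_types = [
    "sliding_attention", "sliding_attention", "full_attention",
    "sliding_attention", "sliding_attention", "full_attention",
]

sliding_window_layers = [i for i, t in enumerate(layer_types) if t == "sliding_attention"]
cache_impl = "sliding_window" if len(sliding_window_layers) == num_hidden_layers else "hybrid"

print(sliding_window_layers)  # [0, 1, 3, 4]
print(cache_impl)             # 'hybrid' because two layers use full attention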
@@ -525,8 +538,13 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  if rbln_config.max_seq_len is None:
  raise ValueError("`max_seq_len` should be specified.")

- if getattr(model_config, "sliding_window", None) is not None and getattr(
- model_config, "use_sliding_window", True
+ layer_types = getattr(model_config, "layer_types", None)
+ all_full_attention = layer_types is not None and all(t == "full_attention" for t in layer_types)
+
+ if (
+ getattr(model_config, "sliding_window", None) is not None
+ and getattr(model_config, "use_sliding_window", True)
+ and not all_full_attention
  ):
  rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
  if rbln_config.sliding_window is not None:
@@ -602,6 +620,9 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  input_ids: Optional[torch.LongTensor] = None,
  inputs_embeds: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ position_embed: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
  **kwargs,
  ) -> BaseModelOutputWithPast:
  """
@@ -623,24 +644,50 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
  )

+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+ )
+ if output_hidden_states != self.rbln_config.output_hidden_states:
+ raise ValueError(
+ f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+ f"Please compile again with the correct argument."
+ )
+
  all_last_hidden_states = []
+ all_hidden_states = (
+ tuple(
+ torch.zeros(
+ self.rbln_config.batch_size,
+ inputs.shape[1],
+ self.config.hidden_size,
+ dtype=self.rbln_config.torch_dtype,
+ )
+ for _ in range(self.config.num_hidden_layers + 1)
+ )
+ if output_hidden_states
+ else None
+ )
  for b_idx in range(self.rbln_config.batch_size):
  query_length = (
  attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
  )
  cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
- last_hidden_states = self.prefill_decoder(
- inputs[b_idx : b_idx + 1],
+ outputs = self.prefill_decoder(
+ input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+ inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
  attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+ position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
  position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
  cache_position=cache_position,
  batch_idx=b_idx,
- ).logits
- all_last_hidden_states.append(last_hidden_states)
+ )
+ all_last_hidden_states.append(outputs.logits)
+ if self.rbln_config.output_hidden_states:
+ for l_idx in range(self.config.num_hidden_layers + 1):
+ all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])

  last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
-
- return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+ return BaseModelOutputWithPast(last_hidden_state=last_hidden_states, hidden_states=all_hidden_states)


  class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
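
When hidden states are requested, the prefill loop above preallocates one zero buffer of shape `(batch_size, seq_len, hidden_size)` per layer (plus one for the embedding output) and copies each sample's states into its batch row. A reduced, runnable sketch of that buffer-and-copy pattern with made-up dimensions:

import torch

batch_size, seq_len, hidden_size, num_layers = 2, 4, 8, 3

# One buffer per layer output, plus one for the embedding output.
all_hidden_states = tuple(
    torch.zeros(batch_size, seq_len, hidden_size) for _ in range(num_layers + 1)
)

for b_idx in range(batch_size):
    # Stand-in for the per-sample prefill call: one (1, seq_len, hidden_size)
    # tensor per layer, as returned for a single batch element.
    sample_hidden_states = tuple(
        torch.randn(1, seq_len, hidden_size) for _ in range(num_layers + 1)
    )
    for l_idx in range(num_layers + 1):
        # In-place copy of this sample's states into row b_idx of the shared buffer.
        all_hidden_states[l_idx][b_idx].copy_(sample_hidden_states[l_idx][0])

print(all_hidden_states[0].shape)  # torch.Size([2, 4, 8])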
@@ -666,12 +713,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
  auto_model_class = AutoModelForCausalLM

  @property
- def prefill_output_size(self):
- return (
- 1,
- self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
- self.config.vocab_size,
- )
+ def logits_last_dim(self):
+ return self.config.vocab_size

  @classmethod
  def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
@@ -736,6 +779,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
  token_type_ids: Optional[torch.Tensor] = None,
  lora_int_ids: Optional[torch.Tensor] = None,
  return_dict: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
  **kwargs,
  ) -> Tuple[torch.FloatTensor]:
  # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
@@ -759,6 +803,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
  )
  padded_cache_lengths = torch.zeros_like(generate_idx)

+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+ )
+ if output_hidden_states != self.rbln_config.output_hidden_states:
+ raise ValueError(
+ f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+ f"Please compile again with the correct argument."
+ )
+
  # Prefill
  if cache_position is None:
  logits = []
@@ -774,19 +827,31 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
  f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
  )

+ all_hidden_states = (
+ tuple(
+ torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.torch_dtype)
+ for _ in range(self.config.num_hidden_layers + 1)
+ )
+ if self.rbln_config.output_hidden_states
+ else None
+ )
  for b_idx in range(batch_size):
  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
- output = self.prefill_decoder(
+ outputs = self.prefill_decoder(
  input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
  inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
  attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+ position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
  cache_position=cache_position,
  batch_idx=b_idx,
  token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
  lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
  )
- padded_cache_lengths[b_idx] += output.padded_cache_lengths
- logits.append(output.logits)
+ padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
+ logits.append(outputs.logits)
+ if self.rbln_config.output_hidden_states:
+ for l_idx in range(self.config.num_hidden_layers + 1):
+ all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
  logits = torch.cat(logits, dim=0)
  # Decoder
  else:
@@ -807,17 +872,22 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
  f"or `max_length` in the generation config."
  )

- logits = self.decoders[batch_size](
+ outputs = self.decoders[batch_size](
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  cache_position=cache_position,
  position_ids=position_ids if self.rbln_config.use_position_ids else None,
  lora_int_ids=lora_int_ids,
- ).logits
+ )
+ logits = outputs.logits
+ all_hidden_states = outputs.hidden_states

  if not return_dict:
- return logits, generate_idx, padded_cache_lengths
+ return logits, generate_idx, padded_cache_lengths, all_hidden_states
  else:
  return RBLNDecoderOnlyOutput(
- logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+ logits=logits,
+ generate_idx=generate_idx,
+ padded_cache_lengths=padded_cache_lengths,
+ hidden_states=all_hidden_states,
  )
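
Both forward paths resolve `output_hidden_states` against the value the model was compiled with and reject a mismatch, since the RBLN graph is fixed at compile time. A runnable sketch of that resolution rule, using a stand-in config object rather than a real `rbln_config`:

from typing import Optional

class _Cfg:
    # Stand-in for rbln_config; in practice this flag is fixed when the model is compiled.
    output_hidden_states = False

def resolve_output_hidden_states(requested: Optional[bool], cfg: _Cfg) -> bool:
    resolved = requested if requested is not None else cfg.output_hidden_states
    if resolved != cfg.output_hidden_states:
        raise ValueError(
            f"Variable output_hidden_states {resolved} is not equal to "
            f"rbln_config.output_hidden_states {cfg.output_hidden_states}. "
            "Please compile again with the correct argument."
        )
    return resolved

cfg = _Cfg()
print(resolve_output_hidden_states(None, cfg))   # False: falls back to the compiled value
print(resolve_output_hidden_states(False, cfg))  # False: explicit match is accepted
# resolve_output_hidden_states(True, cfg) would raise ValueError.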
@@ -64,6 +64,7 @@ class Gemma3TextModel(DecoderOnlyModel):
  global_block_tables: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
  lora_int_id: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
  ):
  # retrieve input_ids and inputs_embeds
  if (input_ids is None) ^ (inputs_embeds is not None):
@@ -96,7 +97,10 @@

  sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)

+ all_hidden_states = () if output_hidden_states else None
  for layer_idx, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
  is_sliding = True if layer_idx in self.sliding_window_layers else False
  hidden_states = layer(
  hidden_states=hidden_states,
@@ -110,7 +114,9 @@
  )

  hidden_states = self.get_last_layernorm()(hidden_states)
- return hidden_states
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ return hidden_states, all_hidden_states


  class Gemma3DecoderLayer(DecoderOnlyLayer):
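
`Gemma3TextModel` now threads `output_hidden_states` through its layer loop, collecting the input to each layer and then the post-layernorm output, for `num_layers + 1` entries in total (the same convention `transformers` uses). A toy sketch of that accumulation order with identity stand-ins for the layers:

import torch

hidden_states = torch.randn(1, 4, 8)              # (batch, seq_len, hidden)
layers = [torch.nn.Identity() for _ in range(3)]  # stand-ins for decoder layers
final_norm = torch.nn.Identity()                  # stand-in for the last layernorm

all_hidden_states = ()
for layer in layers:
    all_hidden_states += (hidden_states,)  # state *entering* this layer
    hidden_states = layer(hidden_states)

hidden_states = final_norm(hidden_states)
all_hidden_states += (hidden_states,)      # final, post-layernorm output

print(len(all_hidden_states))              # len(layers) + 1 == 4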
@@ -16,7 +16,7 @@ from typing import Optional
  import rebel
  import torch

- from ...modeling_outputs import RBLNDecoderOnlyOutput, RBLNGemma3ForCausalLMOutput
+ from ...modeling_outputs import RBLNGemma3ForCausalLMOutput
  from ..decoderonly.decoderonly_runtime_utils import RBLNPytorchRuntime
  from ..decoderonly.modeling_decoderonly import RBLNRuntimeModel

@@ -26,7 +26,6 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
  super().__init__(*args, **kwargs)
  self.image_prefill = RBLNPytorchRuntime(image_prefill) # FIXME(taehoon)
  self.prefill = RBLNPytorchRuntime(self.runtime) if self.phase == "prefill" else None # FIXME
- self.decode = RBLNPytorchRuntime(self.runtime) if self.phase == "decode" else None

  def _prepare_prefill_inputs(self, *args, **kwargs):
  (
@@ -106,6 +105,8 @@
  )

  step = 0
+ output_logits = []
+ all_hidden_states = [] if self.rbln_config.output_hidden_states else None
  while step < query_length:
  if self.rbln_config.use_image_prefill:
  # Check if the prefill chunk is an image prefill
@@ -146,7 +147,7 @@
  query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)

  if is_image_prefill:
- logits = self.image_prefill(
+ outputs = self.image_prefill(
  input_chunk,
  cache_pos_chunk,
  block_tables,
@@ -157,7 +158,7 @@
  lora_int_ids if self.rbln_config.use_lora else None,
  )
  else:
- logits = self.prefill(
+ outputs = self.prefill(
  input_chunk,
  cache_pos_chunk,
  block_tables,
@@ -168,78 +169,49 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
  lora_int_ids if self.rbln_config.use_lora else None,
  )

+ if self.rbln_config.output_hidden_states:
+ output_logits.append(outputs[0])
+ all_hidden_states.append(tuple(outputs[1:]))
+ else:
+ output_logits.append(outputs)
+
  padded_cache_lengths += current_padded_cache_lengths
  step += num_processed_tokens

- if not is_external_block_tables:
- self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
-
- return RBLNGemma3ForCausalLMOutput(
- logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
- )
-
- def decode_forward(
- self,
- inputs: torch.Tensor,
- cache_position: torch.Tensor = None,
- block_tables: torch.Tensor = None,
- is_external_block_tables: bool = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_embed: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- local_block_tables: Optional[torch.Tensor] = None,
- lora_int_ids: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- if self.rbln_config.use_lora and lora_int_ids is None:
- if self.lora_int_ids is None:
- raise ValueError(
- "lora_int_id is required when using LoRA. "
- "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
- )
-
- lora_int_ids = self.lora_int_ids
-
- if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
- raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
-
- batch_size = inputs.shape[0]
- if batch_size != self.batch_size:
- raise RuntimeError(
- f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
- )
+ if self.rbln_config.output_hidden_states:
+ num_hidden_layers = len(all_hidden_states[0]) - 1
+ concatenated_hidden_states = ()
+ for l_idx in range(num_hidden_layers + 1):
+ l_hidden_states = torch.cat([hidden_states[l_idx] for hidden_states in all_hidden_states], dim=1)
+ l_hidden_states = l_hidden_states[:, :query_length, :]
+ concatenated_hidden_states += (l_hidden_states,)

- if batch_size != cache_position.shape[0]:
- raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+ all_hidden_states = concatenated_hidden_states

- # FIXME(taehoon): how to handle pos_attn_mask with external block tables
- if is_external_block_tables:
- if attention_mask is None:
- raise ValueError("attention_mask should be provided with external block tables.")
- if local_block_tables is None:
- raise ValueError("local_block_tables should be provided with external block tables.")
+ # Aggregate output_logits
+ output_logits = torch.concat(output_logits, dim=-2)
+ if self.rbln_config.logits_to_keep > 0:
+ output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
  else:
- local_block_tables = (
- local_block_tables
- if local_block_tables is not None
- else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
- )
- if self.rbln_config.use_attention_mask and attention_mask is None:
- for b_idx in range(batch_size):
- decoding_step = cache_position[b_idx].item()
- if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
- raise ValueError(
- f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
- )
- self.dec_attn_mask[b_idx, decoding_step] = 1
-
- attention_mask = self.dec_attn_mask
-
- if self.batch_size < block_tables.shape[0]:
- block_tables = block_tables[: self.batch_size]
+ output_logits = output_logits[:, :query_length, :]
+ # index copy for masked output_logits
+ if attention_mask is not None:
+ new_output_logits = torch.full(
+ (1, attention_mask.shape[-1], output_logits.shape[-1]),
+ fill_value=1e-10,
+ dtype=output_logits.dtype,
+ )
+ mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+ new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)

- if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
- attention_mask = attention_mask[: self.batch_size]
+ output_logits = new_output_logits

- logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
+ if not is_external_block_tables:
+ self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask

- return RBLNDecoderOnlyOutput(logits=logits)
+ return RBLNGemma3ForCausalLMOutput(
+ logits=output_logits,
+ padded_cache_lengths=padded_cache_lengths,
+ attention_mask=chunked_attention_mask,
+ hidden_states=all_hidden_states,
+ )
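
The Gemma3 prefill runtime now accumulates logits chunk by chunk, trims them to `query_length`, and, when an attention mask stripped padding, scatters them back to their original positions with `index_copy_`, leaving masked slots at a tiny fill value. A self-contained sketch of that scatter with made-up shapes:

import torch

vocab_size, query_length = 5, 3
attention_mask = torch.tensor([1, 0, 1, 1])                # padded sequence of length 4
output_logits = torch.randn(1, query_length, vocab_size)   # logits for the 3 real tokens

# Destination spans every padded position; masked slots keep the 1e-10 fill.
new_output_logits = torch.full(
    (1, attention_mask.shape[-1], vocab_size), fill_value=1e-10, dtype=output_logits.dtype
)
mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]  # tensor([0, 2, 3])
new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)

print(new_output_logits.shape)  # torch.Size([1, 4, 5])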
@@ -299,28 +299,60 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
  generate_idx: Optional[torch.Tensor] = None,
  padded_cache_lengths: Optional[torch.Tensor] = None,
  position_ids: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
  **lm_kwargs: Dict[str, Any],
  ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.rbln_config.language_model.output_hidden_states
+ )
+ if output_hidden_states != self.rbln_config.language_model.output_hidden_states:
+ raise ValueError(
+ f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.language_model.output_hidden_states {self.rbln_config.language_model.output_hidden_states} "
+ f"Please compile again with the correct argument."
+ )
+
  # prefill
  if cache_position is None:
  logits = []
  inputs_embeds = self._preprocess_prefill(input_ids, inputs_embeds, pixel_values)
  batch_size = inputs_embeds.shape[0]

+ all_hidden_states = (
+ tuple(
+ torch.zeros(
+ batch_size,
+ inputs_embeds.shape[1],
+ self.config.text_config.hidden_size,
+ dtype=self.rbln_config.torch_dtype,
+ )
+ for _ in range(self.config.text_config.num_hidden_layers + 1)
+ )
+ if self.rbln_config.language_model.output_hidden_states
+ else None
+ )
+
  for b_idx in range(batch_size):
  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
  token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
  cache_position = self.get_padded_cache_position(cache_position, token_type_id)

- output = self.language_model.prefill_decoder(
+ outputs = self.language_model.prefill_decoder(
  inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
  attention_mask=attention_mask[b_idx],
  cache_position=cache_position,
  batch_idx=b_idx,
  token_type_ids=token_type_ids[b_idx : b_idx + 1], # do not pass token_type_id
  )
- padded_cache_lengths[b_idx] += output.padded_cache_lengths
- logits.append(output.logits)
+ padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
+ logits.append(outputs.logits)
+ if self.rbln_config.language_model.output_hidden_states:
+ for l_idx in range(self.config.text_config.num_hidden_layers + 1):
+ mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
+ all_hidden_states[l_idx][b_idx].index_copy_(
+ dim=0, index=mask_indices, source=outputs.hidden_states[l_idx][0]
+ )

  logits = torch.cat(logits, dim=0)
  # decoder
@@ -334,15 +366,20 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
  f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
  )

- logits = self.language_model.decoders[batch_size](
+ outputs = self.language_model.decoders[batch_size](
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  cache_position=cache_position,
  position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
- ).logits
+ )
+ logits = outputs.logits
+ all_hidden_states = outputs.hidden_states

  return RBLNDecoderOnlyOutput(
- logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+ logits=logits,
+ generate_idx=generate_idx,
+ padded_cache_lengths=padded_cache_lengths,
+ hidden_states=all_hidden_states,
  )


@@ -403,26 +440,6 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
  )
  return embed_tokens

- @classmethod
- def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
- sliding_window = getattr(model_config, "sliding_window", None)
- sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
- if sliding_window_pattern is None:
- if hasattr(model_config, "layer_types"):
- first_full_attention_index = model_config.layer_types.index("full_attention")
- sliding_window_pattern = first_full_attention_index + 1
- else:
- raise ValueError("Cannot determine sliding_window_pattern from model_config")
-
- if sliding_window_pattern <= model_config.num_hidden_layers:
- rbln_config.cache_impl = "hybrid"
- rbln_config.sliding_window = sliding_window
- rbln_config.sliding_window_layers = [
- i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
- ]
-
- return rbln_config
-
  @classmethod
  def _update_submodule_config(
  cls,
@@ -150,7 +150,7 @@ class _GroundingDinoEncoder(torch.nn.Module):
  all_attn_fused_vision = () if output_attentions else None
  all_attn_enhanced_text = () if output_attentions else None
  all_attn_deformable = () if output_attentions else None
- for i, encoder_layer in enumerate(self.layers):
+ for _, encoder_layer in enumerate(self.layers):
  if output_hidden_states:
  encoder_vision_states += (vision_features,)
  encoder_text_states += (text_features,)
@@ -304,7 +304,6 @@ class RBLNGroundingDinoForObjectDetection(RBLNModel):
  for feature_map, mask in vision_features:
  # position encoding
  position_embeddings_list.append(self.backbone_position_embedding(feature_map, mask).to(feature_map.dtype))
- vision_features, position_embeddings_list

  # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
  feature_maps = []
@@ -337,7 +337,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
  pooler_out_size = [pixel_values.shape[0], self.config.vision_config.hidden_size]

  vision_out_buffer = []
- for i in range(self.config.vision_config.num_hidden_layers + 2):
+ for _ in range(self.config.vision_config.num_hidden_layers + 2):
  vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
  if pooler_out_size is not None:
  vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu"))
@@ -300,7 +300,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGeneration
  ]
  pooler_out_size = [pixel_values.shape[0] * pixel_values.shape[1], self.config.vision_config.hidden_size]
  vision_out_buffer = []
- for i in range(self.config.vision_config.num_hidden_layers + 2):
+ for _ in range(self.config.vision_config.num_hidden_layers + 2):
  vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
  vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu"))

@@ -12,13 +12,11 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from transformers import PretrainedConfig

  from ....utils import logging
  from ...models.decoderonly import (
  RBLNDecoderOnlyModel,
  RBLNDecoderOnlyModelForCausalLM,
- RBLNDecoderOnlyModelForCausalLMConfig,
  )
  from .mistral_architecture import MistralWrapper

@@ -85,16 +83,6 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):

  _decoder_wrapper_cls = MistralWrapper

- @classmethod
- def _update_sliding_window_config(
- cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
- ):
- rbln_config.cache_impl = "sliding_window"
- rbln_config.sliding_window = model_config.sliding_window
- rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
-
- return rbln_config
-


  class RBLNMistralModel(RBLNDecoderOnlyModel):
@@ -103,13 +91,3 @@ class RBLNMistralModel(RBLNDecoderOnlyModel):
  """

  _decoder_wrapper_cls = MistralWrapper
-
- @classmethod
- def _update_sliding_window_config(
- cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
- ):
- rbln_config.cache_impl = "sliding_window"
- rbln_config.sliding_window = model_config.sliding_window
- rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
-
- return rbln_config